2 changes: 2 additions & 0 deletions docs/references/environment_variables.md
@@ -95,6 +95,8 @@ SGLang supports various environment variables that can be used to configure its
| --- | --- | --- |
| `SGLANG_TORCH_PROFILER_DIR` | Directory for PyTorch profiler output | `/tmp` |
| `SGLANG_PROFILE_WITH_STACK` | Set `with_stack` option (bool) for PyTorch profiler (capture stack trace) | `true` |
| `SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS` | Set `BatchSpanProcessor.schedule_delay_millis` when tracing is enabled | `500` |
| `SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE` | Set `BatchSpanProcessor.max_export_batch_size` when tracing is enabled | `64` |

## Storage & Caching

54 changes: 44 additions & 10 deletions docs/references/production_request_trace.md
@@ -1,4 +1,4 @@
SGlang exports request trace data based on the OpenTelemetry Collector. You can enable tracing by adding the `--enable-trace` and configure the OpenTelemetry Collector endpoint using `--oltp-traces-endpoint` when launching the server.
SGLang exports request trace data based on the OpenTelemetry Collector. You can enable tracing by adding `--enable-trace` and configuring the OpenTelemetry Collector endpoint with `--otlp-traces-endpoint` when launching the server.

You can find example screenshots of the visualization in https://github.com/sgl-project/sglang/issues/8965.

@@ -22,7 +22,13 @@ This section explains how to configure the request tracing and export the trace

3. start your SGLang server with tracing enabled
```bash
python -m sglang.launch_server --enable-trace --oltp-traces-endpoint 0.0.0.0:4317 <other option>
# set env variables
export SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS=500
export SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE=64
# start the prefill and decode servers
python -m sglang.launch_server --enable-trace --otlp-traces-endpoint 0.0.0.0:4317 <other option>
# start the mini load balancer (router)
python -m sglang_router.launch_router --enable-trace --otlp-traces-endpoint 0.0.0.0:4317 <other option>
```

Replace `0.0.0.0:4317` with the actual endpoint of the OpenTelemetry Collector. If you launched the collector with `tracing_compose.yaml`, the default receiving port is 4317.
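
The two `SGLANG_OTLP_EXPORTER_*` variables tune the OpenTelemetry SDK's batch exporter. The actual wiring happens inside `process_tracing_init`, but the sketch below (an illustration only, not the framework's real code; the endpoint value is a placeholder) shows how they would map onto `BatchSpanProcessor`:

```python
import os

from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor

# Export spans to the collector endpoint passed via --otlp-traces-endpoint.
exporter = OTLPSpanExporter(endpoint="0.0.0.0:4317", insecure=True)

# schedule_delay_millis: how long the processor waits before flushing a batch.
# max_export_batch_size: maximum number of spans sent per export call.
processor = BatchSpanProcessor(
    exporter,
    schedule_delay_millis=int(os.getenv("SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS", "500")),
    max_export_batch_size=int(os.getenv("SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE", "64")),
)

provider = TracerProvider()
provider.add_span_processor(processor)
trace.set_tracer_provider(provider)
```
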
@@ -39,9 +45,9 @@ We have already inserted instrumentation points in the tokenizer and scheduler m

Every process involved in tracing during the initialization phase should execute:
```python
process_tracing_init(oltp_traces_endpoint, server_name)
process_tracing_init(otlp_traces_endpoint, server_name)
```
The oltp_traces_endpoint is obtained from the arguments, and you can set server_name freely, but it should remain consistent across all processes.
The `otlp_traces_endpoint` is obtained from the server arguments; `server_name` can be chosen freely, but it must remain consistent across all processes.
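
For example, a worker process could initialize tracing like this (a sketch based on the calls used elsewhere in this PR; the endpoint and label values are illustrative):

```python
from sglang.srt.tracing.trace import process_tracing_init, trace_set_thread_info

# Once per process, before any spans are created.
process_tracing_init("0.0.0.0:4317", "sglang")

# Once per traced thread; the label identifies the thread span in the trace.
trace_set_thread_info("scheduler")
```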

Every thread involved in tracing during the initialization phase should execute:
```python
@@ -95,24 +101,52 @@ We have already inserted instrumentation points in the tokenizer and scheduler m
trace_set_proc_propagate_context(rid, req.trace_context)
```

5. When the request execution flow transfers to another node (PD disaggregation), the trace context needs to be explicitly propagated; an end-to-end sketch follows this list.
- sender: Execute the following code before sending the request to the remote node via HTTP
```python
trace_context = trace_get_remote_propagate_context(bootstrap_room_list)
headers = {"trace_context": trace_context}
session.post(url, headers=headers)
```
- receiver: Execute the following code after receiving the request via HTTP
```python
trace_set_remote_propagate_context(request.headers['trace_context'])
```
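
Below is a minimal end-to-end sketch of this propagation. It assumes the remote-propagation helpers live in `sglang.srt.tracing.trace` alongside the process-level ones, that the sender uses a `requests.Session`, and that the receiver is a FastAPI endpoint; the endpoint path and payload are illustrative only.

```python
import requests
from fastapi import FastAPI, Request

# Assumed import path: the process-level propagation helpers are imported from this
# module elsewhere in the PR, so the remote variants are assumed to live there too.
from sglang.srt.tracing.trace import (
    trace_get_remote_propagate_context,
    trace_set_remote_propagate_context,
)


def send_to_remote_node(session: requests.Session, url: str, bootstrap_room_list, payload: dict):
    # Serialize the trace context of every request in the batch (keyed by bootstrap room)
    # and ship it in an HTTP header alongside the normal payload.
    trace_context = trace_get_remote_propagate_context(bootstrap_room_list)
    headers = {"trace_context": trace_context}
    return session.post(url, json=payload, headers=headers)


app = FastAPI()


@app.post("/generate")  # hypothetical endpoint
async def generate(request: Request):
    # Re-attach the propagated context before any slice is traced on this node.
    trace_set_remote_propagate_context(request.headers["trace_context"])
    return {"status": "ok"}
```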

## How to Extend the Tracing Framework to Support Complex Tracing Scenarios

The currently provided tracing package still has potential for further development. If you wish to build more advanced features upon it, you must first understand its existing design principles.

The core of the tracing framework's implementation lies in the design of the trace context. To aggregate scattered slices and enable concurrent tracking of multiple requests, we have designed a trace context with a three-level structure.

The core of the tracing framework implementation lies in the design of the trace context. To aggregate scattered slices and enable concurrent tracking of multiple requests, we have designed a three-level trace context structure: `SglangTraceReqContext`, `SglangTraceThreadContext`, and `SglangTraceSliceContext`. Their relationship is as follows:
The core of the tracing framework's implementation lies in the design of the span structure and the trace context. To aggregate scattered slices and enable concurrent tracking of multiple requests, we have designed a two-level trace context structure (`SglangTraceReqContext` and `SglangTraceThreadContext`) and a four-level span structure. The relationship between the trace contexts is as follows:
```
SglangTraceReqContext (req_id="req-123")
├── SglangTraceThreadContext(thread_label="scheduler", tp_rank=0)
│   └── SglangTraceSliceContext (name="prefill")  # cur slice
│
└── SglangTraceThreadContext(thread_label="scheduler", tp_rank=1)
    └── SglangTraceSliceContext (name="prefill")  # cur slice
```

Each traced request maintains a global `SglangTraceReqContext`. For every thread processing the request, a corresponding `SglangTraceThreadContext` is recorded and composed within the `SglangTraceReqContext`. Within each thread, every currently traced slice (possibly nested) is represented by a `SglangTraceSliceContext`, which is stored in the `SglangTraceThreadContext`. Generate a span and release the corresponding context when slice tracing, thread tracing, or request tracing ends.
Each traced request maintains a global `SglangTraceReqContext`. For every thread processing the request, a corresponding `SglangTraceThreadContext` is recorded and composed within the `SglangTraceReqContext`. Within each thread, every currently traced slice (possibly nested) is stored in a list.

In addition to the above hierarchy, each slice also records its previous slice via Span.add_link(), which can be used to trace the execution flow.
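
The following dataclass sketch summarizes this hierarchy; field names are illustrative and the real definitions live in `sglang.srt.tracing.trace`. It assumes each thread context keeps a stack of in-flight slices plus a reference to the previous slice span for `Span.add_link()`:

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

from opentelemetry.trace import Span


@dataclass
class SliceRecord:
    """One in-flight (possibly nested) slice inside a thread."""

    name: str
    span: Span


@dataclass
class SglangTraceThreadContext:
    thread_label: str  # e.g. "scheduler"
    tp_rank: int
    thread_span: Span  # parent of every slice span created in this thread
    slice_stack: List[SliceRecord] = field(default_factory=list)  # currently traced slices
    prev_slice_span: Optional[Span] = None  # linked from the next slice via Span.add_link()


@dataclass
class SglangTraceReqContext:
    req_id: str
    req_root_span: Span
    # One entry per thread (keyed by (thread_label, tp_rank)) that processes the request.
    threads: Dict[Tuple[str, int], SglangTraceThreadContext] = field(default_factory=dict)
```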

When the request execution flow transfers to a new thread, the trace context needs to be explicitly propagated. In the framework, this is represented by `SglangTracePropagateContext`, which contains the context of the request span and the previous slice span.


We designed a four-level span structure consisting of `bootstrap_room_span`, `req_root_span`, `thread_span`, and `slice_span`. `req_root_span` and `thread_span` correspond to `SglangTraceReqContext` and `SglangTraceThreadContext`, respectively, and `slice_span` is stored within the `SglangTraceThreadContext`. The `bootstrap_room_span` exists to accommodate PD disaggregation: each node may want to add its own attributes to the `req_root_span`, but if a single `req_root_span` were shared across all nodes, OpenTelemetry's design would prevent the prefill and decode nodes from adding attributes to it. Each node therefore gets its own `req_root_span` under a shared `bootstrap_room_span`.

```
bootstrap room span
├── router req root span
│   └── router thread span
│       └── slice span
├── prefill req root span
│   ├── tokenizer thread span
│   │   └── slice span
│   └── scheduler thread span
│       └── slice span
└── decode req root span
    ├── tokenizer thread span
    │   └── slice span
    └── scheduler thread span
        └── slice span
```
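
As a rough illustration of how this nesting maps onto the raw OpenTelemetry API (span names and attributes are illustrative; this is not the framework's actual implementation), the hierarchy could be built like this:

```python
from opentelemetry import trace
from opentelemetry.trace import Link, set_span_in_context

tracer = trace.get_tracer("sglang.tracing.sketch")

# One bootstrap room span, shared across router/prefill/decode via context propagation.
bootstrap_span = tracer.start_span("bootstrap room")
bootstrap_ctx = set_span_in_context(bootstrap_span)

# Each node creates its own req root span under the bootstrap room span, so it can
# attach node-local attributes (a span can only be modified where it was created).
req_root_span = tracer.start_span("prefill req root", context=bootstrap_ctx)
req_root_span.set_attribute("disaggregation.mode", "prefill")
req_ctx = set_span_in_context(req_root_span)

# One thread span per thread that touches the request.
thread_span = tracer.start_span("scheduler thread", context=req_ctx)
thread_ctx = set_span_in_context(thread_span)

# Slice spans nest under the thread span; each new slice links back to the previous one.
prev_slice = tracer.start_span("prefill_bootstrap", context=thread_ctx)
prev_slice.end()
cur_slice = tracer.start_span(
    "prefill_forward",
    context=thread_ctx,
    links=[Link(prev_slice.get_span_context())],
)
cur_slice.end()

thread_span.end()
req_root_span.end()
bootstrap_span.end()
```
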
25 changes: 25 additions & 0 deletions python/sglang/srt/disaggregation/decode.py
@@ -58,6 +58,11 @@
    ReqToTokenPool,
    SWAKVPool,
)
from sglang.srt.tracing.trace import (
    trace_event_batch,
    trace_slice_batch,
    trace_slice_end,
)
from sglang.srt.utils import get_int_env_var, require_mlp_sync
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter

@@ -313,6 +318,7 @@ def add(self, req: Req, is_retracted: bool = False) -> None:
        )

        req.add_latency(RequestStage.DECODE_PREPARE)
        trace_slice_end(RequestStage.DECODE_PREPARE, req.rid, auto_next_anon=True)
        self.queue.append(
            DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
        )
@@ -528,6 +534,9 @@ def pop_preallocated(self) -> List[DecodeRequest]:
                    time.perf_counter()
                )
                decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
                trace_slice_end(
                    RequestStage.DECODE_BOOTSTRAP, decode_req.req.rid, auto_next_anon=True
                )

        self.queue = [
            entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
@@ -776,8 +785,19 @@ def pop_transferred(self) -> List[Req]:
                        [decode_req.req], decode_req.req.return_logprob
                    )
                    self.tree_cache.cache_finished_req(decode_req.req)
                    trace_slice_end(
                        RequestStage.DECODE_QUICK_FINISH,
                        decode_req.req.rid,
                        thread_finish_flag=True,
                    )
                else:
                    transferred_reqs.append(decode_req.req)
                    trace_slice_end(
                        RequestStage.DECODE_TRANSFERRED,
                        decode_req.req.rid,
                        auto_next_anon=True,
                    )

            elif poll in [
                KVPoll.Bootstrapping,
                KVPoll.WaitingForInput,
@@ -823,6 +843,7 @@ def event_loop_normal_disagg_decode(self: Scheduler):
                self.stream_output(
                    batch.reqs, any(req.return_logprob for req in batch.reqs)
                )
                trace_slice_batch(RequestStage.DECODE_FAKE_OUTPUT, batch.reqs)
                if prepare_mlp_sync_flag:
                    self._prepare_idle_batch_and_run(None)
                else:
@@ -872,6 +893,7 @@ def event_loop_overlap_disagg_decode(self: Scheduler):
                self.stream_output(
                    batch.reqs, any(req.return_logprob for req in batch.reqs)
                )
                trace_slice_batch(RequestStage.DECODE_FAKE_OUTPUT, batch.reqs)
                if prepare_mlp_sync_flag:
                    batch_, batch_result = self._prepare_idle_batch_and_run(
                        None, delay_process=True
@@ -954,6 +976,9 @@ def get_next_disagg_decode_batch_to_run(
        self.running_batch = self.update_running_batch(self.running_batch)
        ret = self.running_batch if not self.running_batch.is_empty() else None

        if ret:
            attrs = {"bid": hex(id(ret)), "batch_size": ret.batch_size()}
            trace_event_batch("schedule", ret.reqs, attrs=attrs)
        return ret

    def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
19 changes: 19 additions & 0 deletions python/sglang/srt/disaggregation/prefill.py
@@ -53,6 +53,7 @@
    NSATokenToKVPool,
    SWAKVPool,
)
from sglang.srt.tracing.trace import trace_event_batch, trace_slice, trace_slice_end
from sglang.srt.utils import broadcast_pyobj, point_to_point_pyobj, require_mlp_sync

if TYPE_CHECKING:
@@ -198,6 +199,7 @@ def add(self, req: Req, num_kv_heads: int) -> None:
        self._process_req(req)
        req.add_latency(RequestStage.PREFILL_PREPARE)
        self.queue.append(req)
        trace_slice_end(RequestStage.PREFILL_PREPARE, req.rid, auto_next_anon=True)

    def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
        for req in reqs:
@@ -289,6 +291,10 @@ def pop_bootstrapped(
                req.time_stats.wait_queue_entry_time = time.perf_counter()
                req.add_latency(RequestStage.PREFILL_BOOTSTRAP)

                trace_slice_end(
                    RequestStage.PREFILL_BOOTSTRAP, req.rid, auto_next_anon=True
                )

        self.queue = [
            entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
        ]
@@ -316,6 +322,9 @@ def event_loop_normal_disagg_prefill(self: Scheduler) -> None:
            )
            self.process_prefill_chunk()
            batch = self.get_new_batch_prefill()
            if batch:
                attrs = {"bid": hex(id(batch)), "batch_size": batch.batch_size()}
                trace_event_batch("schedule", batch.reqs, attrs=attrs)

            if require_mlp_sync(self.server_args):
                batch = self.prepare_mlp_sync_batch(batch)
@@ -348,6 +357,9 @@ def event_loop_overlap_disagg_prefill(self: Scheduler) -> None:
            )
            self.process_prefill_chunk()
            batch = self.get_new_batch_prefill()
            if batch:
                attrs = {"bid": hex(id(batch)), "batch_size": batch.batch_size()}
                trace_event_batch("schedule", batch.reqs, attrs=attrs)

            if require_mlp_sync(self.server_args):
                batch = self.prepare_mlp_sync_batch(batch)
@@ -423,6 +435,7 @@ def process_batch_result_disagg_prefill(
                req.output_ids.append(next_token_id)
                self.tree_cache.cache_unfinished_req(req)  # update the tree and lock
                req.add_latency(RequestStage.PREFILL_FORWARD)
                trace_slice(RequestStage.PREFILL_FORWARD, req.rid, auto_next_anon=True)
                self.disagg_prefill_inflight_queue.append(req)
                if self.spec_algorithm.is_eagle() and batch.spec_info is not None:
                    req.output_topk_p = batch.spec_info.topk_p[i]
@@ -487,6 +500,9 @@ def process_batch_result_disagg_prefill(

                if self.enable_overlap:
                    self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)
                trace_slice(
                    RequestStage.PREFILL_CHUNKED_FORWARD, req.rid, auto_next_anon=True
                )

        self.maybe_send_health_check_signal()

@@ -558,6 +574,9 @@ def process_disagg_prefill_inflight_queue(
                req.add_latency(RequestStage.PREFILL_TRANSFER_KV_CACHE)
                self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
                req.metadata_buffer_index = -1
                trace_slice(
                    RequestStage.PREFILL_TRANSFER_KV_CACHE, req.rid, thread_finish_flag=True
                )

        self.disagg_prefill_inflight_queue = undone_reqs

11 changes: 7 additions & 4 deletions python/sglang/srt/entrypoints/engine.py
@@ -143,10 +143,13 @@ def __init__(self, **kwargs):

        # Enable tracing
        if server_args.enable_trace:
            process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
            if server_args.disaggregation_mode == "null":
                thread_label = "Tokenizer"
                trace_set_thread_info(thread_label)
            process_tracing_init(server_args.otlp_traces_endpoint, "sglang")
            thread_label = "Tokenizer"
            if server_args.disaggregation_mode == "prefill":
                thread_label = "Prefill Tokenizer"
            elif server_args.disaggregation_mode == "decode":
                thread_label = "Decode Tokenizer"
            trace_set_thread_info(thread_label)

        try:
            self.loop = asyncio.get_running_loop()
9 changes: 6 additions & 3 deletions python/sglang/srt/entrypoints/http_server.py
@@ -220,9 +220,12 @@ async def lifespan(fast_api_app: FastAPI):

    # Init tracing
    if server_args.enable_trace:
        process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
        if server_args.disaggregation_mode == "null":
            trace_set_thread_info(thread_label)
        process_tracing_init(server_args.otlp_traces_endpoint, "sglang")
        if server_args.disaggregation_mode == "prefill":
            thread_label = "Prefill" + thread_label
        elif server_args.disaggregation_mode == "decode":
            thread_label = "Decode" + thread_label
        trace_set_thread_info(thread_label)

    # Initialize OpenAI serving handlers
    fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
2 changes: 2 additions & 0 deletions python/sglang/srt/environ.py
@@ -129,6 +129,8 @@ class Envs:
    SGLANG_SIMULATE_ACC_LEN = EnvFloat(-1)
    SGLANG_SIMULATE_ACC_METHOD = EnvStr("multinomial")
    SGLANG_TORCH_PROFILER_DIR = EnvStr("/tmp")
    SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS = EnvInt(500)
    SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE = EnvInt(64)

    # Scheduler: memory leak test
    SGLANG_TEST_RETRACT = EnvBool(False)
33 changes: 30 additions & 3 deletions python/sglang/srt/managers/data_parallel_controller.py
@@ -34,13 +34,21 @@
    TokenizedGenerateReqInput,
    WatchLoadUpdateReq,
)
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.managers.schedule_batch import Req, RequestStage
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.server_args import (
    DP_ATTENTION_HANDSHAKE_PORT_DELTA,
    PortArgs,
    ServerArgs,
)
from sglang.srt.tracing.trace import (
    process_tracing_init,
    trace_get_proc_propagate_context,
    trace_set_proc_propagate_context,
    trace_set_thread_info,
    trace_slice_end,
    trace_slice_start,
)
from sglang.srt.utils import (
    bind_port,
    configure_logger,
@@ -170,11 +178,22 @@ def send_control_message(self, obj):
    def handle_load_update_req(self, obj):
        self.dp_budget.update_budget(obj)

    def dispatching_with_trace(self, req: Req):
        if self.server_args.enable_trace:
            trace_set_proc_propagate_context(req.rid, req.trace_context)
            trace_slice_start(RequestStage.DC_DISPATCH, req.rid)
            req.trace_context = trace_get_proc_propagate_context(req.rid)

        self.dispatching(req)

        if self.server_args.enable_trace:
            trace_slice_end(RequestStage.DC_DISPATCH, req.rid, thread_finish_flag=True)

    def init_dispatcher(self):
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.dispatching),
                (TokenizedEmbeddingReqInput, self.dispatching),
                (TokenizedGenerateReqInput, self.dispatching_with_trace),
                (TokenizedEmbeddingReqInput, self.dispatching_with_trace),
                (BlockReqInput, self.send_to_all_workers),
                (WatchLoadUpdateReq, self.handle_load_update_req),
            ]
@@ -487,6 +506,14 @@ def run_data_parallel_controller_process(
    pipe_writer,
):
    kill_itself_when_parent_died()
    if server_args.enable_trace:
        process_tracing_init(server_args.otlp_traces_endpoint, "sglang")
        thread_label = "DP Controller"
        if server_args.disaggregation_mode == "prefill":
            thread_label = "Prefill DP Controller"
        elif server_args.disaggregation_mode == "decode":
            thread_label = "Decode DP Controller"
        trace_set_thread_info(thread_label)
    setproctitle.setproctitle("sglang::data_parallel_controller")
    faulthandler.enable()
    configure_logger(server_args)
16 changes: 14 additions & 2 deletions python/sglang/srt/managers/schedule_batch.py
@@ -392,20 +392,32 @@ def merge(self, other: MultimodalInputs):


class RequestStage(str, enum.Enum):
    # prefill
    # Tokenizer
    TOKENIZE = "tokenize"
    TOKENIZER_DISPATCH = "dispatch"

    # DP controller
    DC_DISPATCH = "dc_dispatch"

    # common/non-disaggregation
    PREFILL_WAITING = "prefill_waiting"
    REQUEST_PROCESS = "request_process"
    DECODE_LOOP = "decode_loop"
    PREFILL_FORWARD = "prefill_forward"
    PREFILL_CHUNKED_FORWARD = "chunked_prefill"

    # disaggregation prefill
    PREFILL_PREPARE = "prefill_prepare"
    PREFILL_BOOTSTRAP = "prefill_bootstrap"
    PREFILL_FORWARD = "prefill_forward"
    PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache"

    # disaggregation decode
    DECODE_PREPARE = "decode_prepare"
    DECODE_BOOTSTRAP = "decode_bootstrap"
    DECODE_WAITING = "decode_waiting"
    DECODE_TRANSFERRED = "decode_transferred"
    DECODE_FAKE_OUTPUT = "fake_output"
    DECODE_QUICK_FINISH = "quick_finish"


class Req: