
Commit 95c81dd

backed out of some weird changes that claude made

1 parent 0097130

File tree

7 files changed, +90 −324 lines changed

python/sglang/srt/layers/logits_processor.py

Lines changed: 3 additions & 20 deletions
@@ -366,15 +366,13 @@ def forward(
         """
         verification_hidden_states_to_store: Optional[torch.Tensor] = None
         if logits_metadata.verification_algorithm.is_toploc():
+            logger.debug(
+                f"Capturing TopLoc verification hidden states with shape {pruned_states.shape if pruned_states is not None else 'None'}"
+            )
             verification_hidden_states_to_store = (
                 pruned_states[sample_indices] if sample_indices else pruned_states
             )

-        # For TOPLOC verification algorithm, capture hidden states and generate proof
-        verification_proof: Optional[list] = None
-        if logits_metadata.verification_algorithm.is_toploc():
-            verification_proof = self.generate_verification_proof(hidden_states)
-
         if not logits_metadata.extend_return_logprob:
             # Decode mode or extend mode without return_logprob.
             return LogitsProcessorOutput(

@@ -573,21 +571,6 @@ def compute_temp_top_p_normalized_logprobs(
         else:
             return torch.nn.functional.log_softmax(last_logits, dim=-1)

-    def generate_verification_proof(self, hidden_states: torch.Tensor) -> list:
-        """Generate a verification proof from hidden states.
-
-        The proof is a fingerprint or hash-like representation of the hidden states.
-        In this implementation, we use a simple mean of the hidden states as a proof,
-        but more sophisticated methods could be implemented.
-
-        Args:
-            hidden_states: The hidden states to generate proof from
-
-        Returns:
-            A list representation of the proof
-        """
-        return []
-

 @triton.jit
 def fused_softcap_kernel(
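For context, a standalone sketch of the capture gating this hunk keeps. The helper name, shapes, and explicit None check are assumptions for illustration (truth-testing a multi-element index tensor would raise, so the sketch checks `is not None` rather than relying on truthiness):

    from typing import Optional

    import torch

    def capture_verification_states(
        pruned_states: Optional[torch.Tensor],
        sample_indices: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """Pick the rows whose hidden states feed TopLoc proof generation."""
        if pruned_states is None:
            return None
        # Select only the sampled positions when indices are given;
        # otherwise keep every row of the [num_tokens, hidden] tensor.
        if sample_indices is not None:
            return pruned_states[sample_indices]
        return pruned_states

    states = capture_verification_states(torch.randn(4, 8), torch.tensor([3]))
    print(states.shape)  # torch.Size([1, 8])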

python/sglang/srt/managers/scheduler_output_processor_mixin.py

Lines changed: 18 additions & 0 deletions
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import logging
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput

@@ -14,6 +15,8 @@
     ScheduleBatch,
 )

+logger = logging.getLogger(__name__)
+

 class SchedulerOutputProcessorMixin:
     """

@@ -120,6 +123,9 @@ def process_batch_result_prefill(
                     and logits_output.verification_hidden_states is not None
                 ):
                     # Process verification hidden states for the current request
+                    logger.debug(
+                        f"Processing verification hidden states for prefill in req {req.req_id}"
+                    )
                     req.verification_proofs.append(
                         create_toploc_proofs(
                             logits_output.verification_hidden_states[

@@ -132,6 +138,9 @@ def process_batch_result_prefill(
                             .clone()
                         )
                     )
+                    logger.debug(
+                        f"Added verification proof #{len(req.verification_proofs)} to req {req.req_id} (prefill)"
+                    )

                 if req.grammar is not None:
                     req.grammar.accept_token(next_token_id)

@@ -270,11 +279,17 @@ def process_batch_result_decode(
                 )

                 if logits_output.verification_hidden_states is not None:
+                    logger.debug(
+                        f"Processing verification hidden states for decode in req {req.req_id}"
+                    )
                     req.verification_proofs.append(
                         create_toploc_proofs(
                             logits_output.verification_hidden_states[i].cpu().clone()
                         )
                     )
+                    logger.debug(
+                        f"Added verification proof #{len(req.verification_proofs)} to req {req.req_id} (decode)"
+                    )

                 if req.grammar is not None and batch.spec_algorithm.is_none():
                     req.grammar.accept_token(next_token_id)

@@ -589,6 +604,9 @@ def stream_output_generation(
                 if req.return_verification_proofs and hasattr(
                     req, "verification_proofs"
                 ):
+                    logger.debug(
+                        f"Collecting verification proofs for req {req.req_id}: {len(req.verification_proofs) if req.verification_proofs else 0} proofs"
+                    )
                     if verification_proofs is None:
                         verification_proofs = []
                     verification_proofs.append(req.verification_proofs)
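One note on the debug logging added throughout this file: f-strings are formatted even when DEBUG is disabled. A minimal sketch of the deferred %-style alternative, where `Req` is a stand-in for the scheduler's real request type, not the actual class:

    import logging
    from dataclasses import dataclass, field

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    @dataclass
    class Req:  # stand-in for sglang's request object
        req_id: str
        verification_proofs: list = field(default_factory=list)

    req = Req("req-0", ["proof0"])
    # Arguments are interpolated only if a handler actually emits DEBUG records.
    logger.debug(
        "Added verification proof #%d to req %s (decode)",
        len(req.verification_proofs),
        req.req_id,
    )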

python/sglang/srt/model_executor/model_runner.py

Lines changed: 1 addition & 20 deletions
@@ -18,14 +18,12 @@
 import json
 import logging
 import os
-import threading
 import time
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
-from toploc import build_proofs_base64

 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig

@@ -57,7 +55,7 @@
 )
 from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
-from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
 from sglang.srt.model_loader.loader import (
     DefaultModelLoader,

@@ -125,19 +123,6 @@ def __init__(
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator

-        # Activation saving setup
-        self.save_activations = server_args.toploc_fingerprint
-        if self.save_activations:
-            self.capture_hidden_mode = (
-                CaptureHiddenMode.LAST
-            )  # Only capture final hidden state
-            self.verification_algorithm = (
-                VerificationAlgorithm.TOPLOC
-            )  # Set verification algorithm
-            self.is_cuda_graph_capturing = (
-                False  # Flag to track CUDA graph capturing state
-            )
-
         # Model-specific adjustment
         self.model_specific_adjustment()

@@ -921,9 +906,7 @@ def init_cuda_graphs(self):
         logger.info(
             f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
-
         self.cuda_graph_runner = CudaGraphRunner(self)
-
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. "

@@ -978,8 +961,6 @@ def forward_idle(self, forward_batch: ForwardBatch):
     def forward(
         self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
    ) -> LogitsProcessorOutput:
-        """Run the forward pass."""
-        # Run cuda graph if possible
         if (
             forward_batch.forward_mode.is_cuda_graph()
             and self.cuda_graph_runner

python/sglang/srt/openai_api/adapter.py

Lines changed: 37 additions & 4 deletions
@@ -732,6 +732,32 @@ async def generate_stream_resp():
                 prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
                 completion_tokens[index] = content["meta_info"]["completion_tokens"]

+                if not stream_buffer:  # The first chunk
+                    if request.echo:
+                        if isinstance(request.prompt, str):
+                            # for the case of single str prompts
+                            prompts = request.prompt
+                        elif isinstance(request.prompt, list):
+                            if isinstance(request.prompt[0], str):
+                                # for the case of multiple str prompts
+                                prompts = request.prompt[index // request.n]
+                            elif isinstance(request.prompt[0], int):
+                                # for the case of single token ids prompt
+                                prompts = tokenizer_manager.tokenizer.decode(
+                                    request.prompt, skip_special_tokens=True
+                                )
+                            elif isinstance(request.prompt[0], list) and isinstance(
+                                request.prompt[0][0], int
+                            ):
+                                # for the case of multiple token ids prompts
+                                prompts = tokenizer_manager.tokenizer.decode(
+                                    request.prompt[index // request.n],
+                                    skip_special_tokens=True,
+                                )
+
+                        # Prepend prompt in response text.
+                        text = prompts + text
+
                 if request.logprobs is not None:
                     # The first chunk and echo is enabled.
                     if not stream_buffer and request.echo:

@@ -1070,6 +1096,9 @@ def v1_chat_generate_response(

         finish_reason = ret_item["meta_info"]["finish_reason"]

+        tool_calls = None
+        text = ret_item["text"]
+
         if isinstance(request, list):
             tool_choice = request[idx].tool_choice
             tools = request[idx].tools

@@ -1084,7 +1113,7 @@ def v1_chat_generate_response(
                 parser = ReasoningParser(
                     model_type=reasoning_parser, stream_reasoning=False
                 )
-                reasoning_text, text = parser.parse_non_stream(ret_item["text"])
+                reasoning_text, text = parser.parse_non_stream(text)
             except Exception as e:
                 logger.error(f"Exception: {e}")
                 return create_error_response(

@@ -1093,10 +1122,8 @@ def v1_chat_generate_response(
                 )
         else:
             reasoning_text = None
-            text = ret_item["text"]

-        tool_calls = None
-        if tool_call_parser and tool_choice != "none" and tools:
+        if tool_choice != "none" and tools:
             parser = FunctionCallParser(tools, tool_call_parser)
             if parser.has_tool_call(text):
                 if finish_reason["type"] == "stop":

@@ -1122,6 +1149,12 @@

         # Extract verification proofs if available
         verification_proofs = ret_item["meta_info"].get("verification_proofs", None)
+        if verification_proofs:
+            logger.debug(
+                f"Retrieved verification proofs from response: {len(verification_proofs)} proof sets"
+            )
+        else:
+            logger.debug("No verification proofs found in response")

         if to_file:
             # to make the choice data json serializable
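The restored echo branch dispatches on four prompt shapes. A compact sketch of the same dispatch under stated assumptions: `echo_prompt` is a hypothetical helper, and `decode` stands in for `tokenizer_manager.tokenizer.decode`:

    from typing import Callable, List, Union

    Prompt = Union[str, List[str], List[int], List[List[int]]]

    def echo_prompt(prompt: Prompt, index: int, n: int, decode: Callable[..., str]) -> str:
        """Return the prompt text to prepend for completion `index` (n samples per prompt)."""
        if isinstance(prompt, str):
            return prompt  # single str prompt
        if isinstance(prompt[0], str):
            return prompt[index // n]  # multiple str prompts
        if isinstance(prompt[0], int):
            return decode(prompt, skip_special_tokens=True)  # single token-ids prompt
        # remaining case: multiple token-ids prompts (list of lists of ints)
        return decode(prompt[index // n], skip_special_tokens=True)

    fake_decode = lambda ids, skip_special_tokens=True: f"<{len(ids)} tokens>"
    print(echo_prompt([[1, 2], [3, 4, 5]], index=1, n=1, decode=fake_decode))  # <3 tokens>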

python/sglang/srt/server_args.py

Lines changed: 4 additions & 0 deletions
@@ -1134,6 +1134,10 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
     ServerArgs.add_cli_args(parser)
     raw_args = parser.parse_args(argv)
     server_args = ServerArgs.from_cli_args(raw_args)
+    if server_args.toploc_fingerprint:
+        logger.info(
+            f"TopLoc fingerprint verification enabled with topk={server_args.toploc_verification_topk}"
+        )
     return server_args

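Assuming the `toploc_fingerprint` field follows argparse's usual underscore-to-dash flag spelling (the flag definition itself is outside this diff), the new log line would fire on a launch such as:

    from sglang.srt.server_args import prepare_server_args

    # Hypothetical invocation; --toploc-fingerprint is the assumed CLI spelling.
    server_args = prepare_server_args(
        ["--model-path", "meta-llama/Llama-3.1-8B-Instruct", "--toploc-fingerprint"]
    )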
Lines changed: 27 additions & 6 deletions
@@ -1,10 +1,13 @@
+import logging
 from typing import Optional

 import torch
 from toploc import build_proofs_base64

 from sglang.srt.managers.schedule_batch import global_server_args_dict

+logger = logging.getLogger(__name__)
+

 def create_toploc_proofs(
     verification_hidden_states: Optional[torch.Tensor],

@@ -18,16 +21,34 @@ def create_toploc_proofs(
     Returns:
         The hidden states tensor moved to CPU or None if input was None
     """
+    if verification_hidden_states is None:
+        logger.warning(
+            "Attempted to create TopLoc proofs with None verification_hidden_states"
+        )
+        return None
+
+    logger.debug(
+        f"Creating TopLoc proofs from tensor with shape {verification_hidden_states.shape}"
+    )

     # Move to CPU. Will have size [N, hidden] - each one should represent a "last token"
     verification_hidden_states = verification_hidden_states.detach().cpu()

     topk = global_server_args_dict["toploc_verification_topk"]
+    logger.debug(f"Using TopLoc verification topk={topk}")

     # Will return N proofs
-    return build_proofs_base64(
-        verification_hidden_states,
-        decode_batching_size=3,
-        topk=topk,
-        skip_prefill=False,
-    )
+    try:
+        proofs = build_proofs_base64(
+            verification_hidden_states,
+            decode_batching_size=3,
+            topk=topk,
+            skip_prefill=False,
+        )
+        logger.debug(
+            f"Successfully generated {len(proofs) if proofs else 0} TopLoc proofs"
+        )
+        return proofs
+    except Exception as e:
+        logger.error(f"Error generating TopLoc proofs: {str(e)}")
+        return None
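A quick usage sketch of the underlying toploc call, mirroring the arguments in this hunk; the hidden size and topk values are illustrative, not the server defaults:

    import torch
    from toploc import build_proofs_base64

    # Two fake "last token" hidden-state rows, shape [N, hidden].
    hidden_states = torch.randn(2, 4096, dtype=torch.bfloat16)

    proofs = build_proofs_base64(
        hidden_states,
        decode_batching_size=3,
        topk=128,
        skip_prefill=False,
    )
    print(f"{len(proofs)} base64-encoded proofs")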
