Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
303 changes: 228 additions & 75 deletions src/lm_polygraph/utils/vllm_with_uncertainty.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import numpy as np
import torch
from tqdm import tqdm
from vllm import LLM, SamplingParams, TokensPrompt
from vllm import LLM, SamplingParams
from vllm.outputs import RequestOutput

from lm_polygraph.estimators import Estimator
Expand All @@ -53,6 +53,105 @@
log = logging.getLogger()


def _fill_prefix_gaps(
    captured: Dict[str, np.ndarray],
    metadata: Dict[str, Dict],
    prompt_groups: Dict[str, List[str]],
    prompt_tokens: Optional[Dict[str, List[int]]] = None,
) -> Dict[str, np.ndarray]:
    """
    Fill prefix-cache gaps in captured hidden states.

    When vLLM's prefix cache hits, the cached prompt prefix is never
    re-computed, so the captured hidden-state array for a request starts
    at prompt position ``prompt_len - prefill_tokens`` instead of 0.

    Phase 1 — Within-group: for identical prompts, copy the missing prefix
    rows from the group donor (the request with the most prefill tokens).

    Phase 2 — Cross-group (requires prompt_tokens): vLLM caches the shared
    chat-template prefix across different prompts. Uses token-level
    longest-common-prefix (LCP) with a global donor to fill remaining gaps.
    If the donor cannot cover the whole missing prefix, the uncovered hole
    is zero-filled so that every row stays at its true prompt position
    (a partial prepend without the hole would silently misalign all
    subsequent tokens).

    Args:
        captured: {req_id: numpy_array} — deserialized hidden states,
            shaped (num_captured_tokens, hidden_dim).
        metadata: {req_id: {"total_computed": int, "prefill_tokens": int}}
        prompt_groups: {group: [req_id, ...]} — requests with identical prompts
        prompt_tokens: {req_id: [token_ids]} — for cross-group filling
    Returns:
        {req_id: numpy_array} with gaps filled.
    """
    result = dict(captured)

    # Effective prefill coverage per request; updated as gaps are filled.
    eff_prefill: Dict[str, int] = {
        rid: metadata.get(rid, {}).get("prefill_tokens", 0) for rid in result
    }

    # ---- Phase 1: within-group fill (identical prompts) ----
    for req_ids in prompt_groups.values():
        group_reqs = [rid for rid in req_ids if rid in result]
        if len(group_reqs) < 2:
            continue
        # Donor: the group member whose prefill covers the most positions.
        donor_id = max(group_reqs, key=lambda rid: eff_prefill.get(rid, 0))
        donor_prefill = eff_prefill[donor_id]
        donor_arr = result[donor_id]
        for req_id in group_reqs:
            if req_id == donor_id:
                continue
            # The recipient's captured rows start `gap` prompt positions
            # later than the donor's; since the prompts are identical, the
            # donor's first `gap` rows are exactly the missing prefix.
            # NOTE(review): assumes donor_arr has >= donor_prefill rows,
            # i.e. metadata is consistent with the array — verify upstream.
            gap = donor_prefill - eff_prefill[req_id]
            if gap > 0:
                result[req_id] = np.concatenate(
                    [donor_arr[:gap], result[req_id]], axis=0
                )
                eff_prefill[req_id] = donor_prefill

    # ---- Phase 2: cross-group fill (shared token prefix) ----
    if not prompt_tokens:
        return result

    all_req_ids = list(result.keys())
    if len(all_req_ids) < 2:
        return result

    # Global donor must have full prefill (its rows start at position 0),
    # so its leading rows are positionally aligned with any shared prefix.
    # Among full-prefill requests, prefer the longest prompt.
    full_prefill_reqs = [
        rid
        for rid in all_req_ids
        if eff_prefill.get(rid, 0) >= len(prompt_tokens.get(rid, []))
    ]
    if not full_prefill_reqs:
        return result

    global_donor = max(
        full_prefill_reqs, key=lambda rid: len(prompt_tokens.get(rid, []))
    )
    gd_prefill = eff_prefill[global_donor]
    gd_tokens = prompt_tokens.get(global_donor, [])
    gd_arr = result[global_donor]

    for req_id in all_req_ids:
        if req_id == global_donor:
            continue
        req_toks = prompt_tokens.get(req_id, [])
        req_prompt_len = len(req_toks)
        cur_prefill = eff_prefill[req_id]
        if cur_prefill >= req_prompt_len:
            continue  # prompt fully covered, nothing to fill

        # Token-level longest common prefix with the global donor.
        lcp = 0
        for a, b in zip(gd_tokens, req_toks):
            if a != b:
                break
            lcp += 1

        missing = req_prompt_len - cur_prefill
        fillable = min(missing, lcp, gd_prefill)
        if fillable <= 0:
            continue

        pieces = [gd_arr[:fillable]]
        # BUGFIX: a partial fill (fillable < missing) used to concatenate
        # the donor prefix directly onto the recipient's rows, leaving a
        # hole and silently shifting every later token to the wrong prompt
        # position. Zero-fill the uncovered hole so the array stays
        # positionally aligned from position 0.
        hole = missing - fillable
        if hole > 0:
            pieces.append(
                np.zeros((hole,) + gd_arr.shape[1:], dtype=gd_arr.dtype)
            )
        pieces.append(result[req_id])
        result[req_id] = np.concatenate(pieces, axis=0)
        # The array now spans every prompt position (real rows + zero pad).
        eff_prefill[req_id] = cur_prefill + missing

    return result


def _safe_float_uncertainty(u: Any) -> float:
"""Robustly extract a scalar from estimator output."""
if isinstance(u, (float, int)):
Expand Down Expand Up @@ -575,10 +674,15 @@ def _raw_generate(
if self.prompt_logprobs:
sampling_params.prompt_logprobs = sampling_params.logprobs

# ---- 1) MAIN GENERATION (without HS capture) ----
# ---- Setup HS capture BEFORE generation if using native path ----
if self.output_hidden_states and self.use_native_hs_capture:
self._ensure_hs_extension()
self._engine_core.collective_rpc("_reset_capture")

# ---- 1) MAIN GENERATION (with HS capture if native path) ----
outputs = self.llm.generate(prompts_list, sampling_params)

# ---- 2) Setup HS capture for additional forward pass ----
# ---- 2) Extract hidden states ----
hs_by_req_out: List[List[Optional[List[torch.Tensor]]]] = [
[None] * len(ro.outputs) for ro in outputs
]
Expand All @@ -593,88 +697,137 @@ def _raw_generate(
)

if self.use_native_hs_capture:
# Path 2: Native capture - sequential additional forward pass
# Process each sequence separately to avoid prefix sharing issues
self._ensure_hs_extension()

# Create sampling params for 1 token generation
one_token_params = SamplingParams(
temperature=0,
max_tokens=1,
)

log.info(
f"Running sequential additional forward pass with HS capture for {len(flat_full_ids)} sequences"
)

# Process each sequence separately (batch_size=1)
for k, (i, j) in enumerate(flat_index):
seq_ids = flat_full_ids[k]
seq_len = len(seq_ids)

log.info(
f"Processing sequence {k+1}/{len(flat_full_ids)} (req_idx={i}, out_idx={j}, seq_len={seq_len})"
# Path 2: Native capture — HS were captured during main generation.
# Extract per-request hidden states and fill prefix-cache gaps.

per_rank = self._engine_core.collective_rpc("_get_captured_states")
captured_raw: dict = per_rank[0] if per_rank else {}
meta_raw = self._engine_core.collective_rpc("_get_capture_metadata")
meta: dict = meta_raw[0] if meta_raw else {}

# Map output request_id to prompt index in prompts_list.
# vLLM request ID formats:
# "6" — simple numeric
# "6-abc123" — numeric prefix + hash suffix
# "2_0" — format {seq_idx}_{prompt_idx} for best-of-n
output_short_ids: Dict[str, int] = {}
for i, out in enumerate(outputs):
output_short_ids[out.request_id] = i
# Index by numeric prefix: "6-ab" -> "6"
if "-" in out.request_id:
prefix = out.request_id.split("-")[0]
output_short_ids.setdefault(prefix, i)
# Index by suffix for {seq}_{prompt} format: "0_1" -> "1"
if "_" in out.request_id:
suffix = out.request_id.rsplit("_", 1)[-1]
output_short_ids.setdefault(suffix, i)

def _resolve_prompt_idx(req_id: str) -> Optional[int]:
# Exact match
if req_id in output_short_ids:
return output_short_ids[req_id]
# Format {seq}_{prompt}: suffix is prompt index
if "_" in req_id:
suffix = req_id.rsplit("_", 1)[-1]
if suffix in output_short_ids:
return output_short_ids[suffix]
# Format {prompt}-{hash}: prefix is prompt index
if "-" in req_id:
prefix = req_id.split("-")[0]
if prefix in output_short_ids:
return output_short_ids[prefix]
return None

def _resolve_seq_idx(req_id: str) -> Optional[int]:
"""Extract sequence index: '2_0' -> seq=2 (prefix)."""
if "_" in req_id:
parts = req_id.rsplit("_", 1)
try:
return int(parts[0])
except ValueError:
pass
return None

# Process each captured layer
for lid in self.hs_layer_ids:
layer_data = captured_raw.get(lid, {})
if not layer_data:
continue

# Deserialize per-request arrays
captured_arrays: Dict[str, np.ndarray] = {}
for req_id, pickled_bytes in layer_data.items():
captured_arrays[req_id] = pickle.loads(pickled_bytes)

# Build prompt groups and prompt_tokens for gap filling
prompt_groups: Dict[str, List[str]] = {}
req_prompt_tokens: Dict[str, List[int]] = {}
for req_id in captured_arrays:
idx = _resolve_prompt_idx(req_id)
if idx is None:
continue
text = prompts_list[idx]
prompt_groups.setdefault(text, []).append(req_id)
req_prompt_tokens[req_id] = prompt_token_ids[idx]

# Fill prefix-cache gaps
filled = _fill_prefix_gaps(
captured_arrays,
meta,
prompt_groups,
prompt_tokens=req_prompt_tokens,
)

# Reset capture buffer and prefix cache for each sequence
self._engine_core.collective_rpc("_reset_capture")
self._reset_hs_prefix_cache()

# Generate 1 token for THIS SEQUENCE ONLY (batch_size=1)
_ = self.llm.generate(
[TokensPrompt(prompt_token_ids=seq_ids)],
sampling_params=one_token_params,
)

# Extract captured states for this sequence
per_rank = self._engine_core.collective_rpc("_get_captured_states")
captured_raw: dict = per_rank[0] if per_rank else {}

# Deserialize pickled arrays
captured = {}
for lid, pickled_bytes in captured_raw.items():
arr = pickle.loads(pickled_bytes)
captured[lid] = torch.from_numpy(arr)
log.info(
f" Sequence {k}: Layer {lid}: captured shape={arr.shape}, dtype={arr.dtype}"
)

# Verify we captured exactly seq_len tokens
total_captured = 0
if captured:
first_lid = list(captured.keys())[0]
total_captured = captured[first_lid].shape[0]
log.info(
f" Sequence {k}: captured {total_captured} tokens, expected {seq_len}, match={total_captured == seq_len}"
)

# Form hidden states list for this sequence
layer_states = []
for lid in self.hs_layer_ids:
if lid in captured:
tensor = captured[lid]
# We expect exactly seq_len tokens for this single sequence
if tensor.shape[0] == seq_len:
layer_states.append(tensor)
else:
log.warning(
f" Sequence {k}: Layer {lid} has {tensor.shape[0]} tokens, expected {seq_len}"
# Assign filled HS to the right (request, output) slot.
# With best-of-n, vLLM creates sub-requests like "4_0",
# "4_1" — each with its own hidden states. Map the suffix
# to the correct output sequence index j.
for req_id, arr in filled.items():
idx = _resolve_prompt_idx(req_id)
if idx is None:
continue

seq_j = _resolve_seq_idx(req_id)
num_outputs = len(outputs[idx].outputs)

if seq_j is not None and seq_j < num_outputs:
# Specific sequence index from req_id (e.g. "4_0")
target_js = [seq_j]
else:
# No sequence suffix or n=1 — assign to all outputs
target_js = list(range(num_outputs))

for j in target_js:
out = outputs[idx].outputs[j]
gen_len = len(getattr(out, "token_ids", []) or [])
expected = len(prompt_token_ids[idx]) + gen_len

tensor = torch.from_numpy(arr)

# Trim if captured more than expected
if tensor.shape[0] > expected:
tensor = tensor[:expected]

# Pad with zero vectors if short
# (last generated token may not get a forward pass)
if tensor.shape[0] < expected:
pad_count = expected - tensor.shape[0]
pad = torch.zeros(
pad_count, tensor.shape[1], dtype=tensor.dtype
)
# Slice to expected length if needed
layer_states.append(tensor[:seq_len])
tensor = torch.cat([tensor, pad], dim=0)

hs_by_req_out[i][j] = (
{"hidden_states": layer_states} if layer_states else None
)
if hs_by_req_out[idx][j] is None:
hs_by_req_out[idx][j] = {"hidden_states": []}
hs_by_req_out[idx][j]["hidden_states"].append(tensor)

# Cleanup: remove hooks and reset capture buffer after generation
# Cleanup hooks
self._engine_core.collective_rpc(
"_setup_hidden_states_capture", args=([],)
)
self._engine_core.collective_rpc("_reset_capture")
# Reset flag so hooks are re-registered on next run
self._hs_extension_ready = False
log.info("Removed HS capture hooks and reset buffer after generation")
log.info("Native HS capture: extracted per-request hidden states")
else:
# Path 1: VllmHiddenStatesGenerator - separate generation
# Sleep main LLM to free GPU memory before HS generator loads
Expand Down
Loading