diff --git a/src/lm_polygraph/utils/vllm_with_uncertainty.py b/src/lm_polygraph/utils/vllm_with_uncertainty.py index 3fdd0a96..ce6e78ee 100644 --- a/src/lm_polygraph/utils/vllm_with_uncertainty.py +++ b/src/lm_polygraph/utils/vllm_with_uncertainty.py @@ -35,7 +35,7 @@ import numpy as np import torch from tqdm import tqdm -from vllm import LLM, SamplingParams, TokensPrompt +from vllm import LLM, SamplingParams from vllm.outputs import RequestOutput from lm_polygraph.estimators import Estimator @@ -53,6 +53,105 @@ log = logging.getLogger() +def _fill_prefix_gaps( + captured: Dict[str, np.ndarray], + metadata: Dict[str, Dict], + prompt_groups: Dict[str, List[str]], + prompt_tokens: Optional[Dict[str, List[int]]] = None, +) -> Dict[str, np.ndarray]: + """ + Fill prefix-cache gaps in captured hidden states. + + Phase 1 — Within-group: for identical prompts, copy the missing prefix + from the group donor (request with the most prefill tokens). + + Phase 2 — Cross-group (requires prompt_tokens): vLLM caches the shared + chat-template prefix across different prompts. Uses token-level + longest-common-prefix (LCP) with a global donor to fill remaining gaps. + + Args: + captured: {req_id: numpy_array} — deserialized hidden states + metadata: {req_id: {"total_computed": int, "prefill_tokens": int}} + prompt_groups: {group: [req_id, ...]} — requests with identical prompts + prompt_tokens: {req_id: [token_ids]} — for cross-group filling + Returns: + {req_id: numpy_array} with gaps filled. + """ + result = dict(captured) + + # Track effective prefill per request (updated as we fill) + eff_prefill: Dict[str, int] = { + rid: metadata.get(rid, {}).get("prefill_tokens", 0) for rid in result + } + + # Phase 1: within-group fill (identical prompts) + for _, req_ids in prompt_groups.items(): + group_reqs = [rid for rid in req_ids if rid in result] + if len(group_reqs) < 2: + continue + donor_id = max(group_reqs, key=lambda rid: eff_prefill.get(rid, 0)) + donor_prefill = eff_prefill[donor_id] + donor_arr = result[donor_id] + for req_id in group_reqs: + if req_id == donor_id: + continue + gap = donor_prefill - eff_prefill[req_id] + if gap > 0: + result[req_id] = np.concatenate( + [donor_arr[:gap], result[req_id]], axis=0 + ) + eff_prefill[req_id] = donor_prefill + + # Phase 2: cross-group fill (shared token prefix) + if not prompt_tokens: + return result + + all_req_ids = list(result.keys()) + if len(all_req_ids) < 2: + return result + + # Global donor: pick from requests with full prefill, then longest prompt + full_prefill_reqs = [ + rid + for rid in all_req_ids + if eff_prefill.get(rid, 0) >= len(prompt_tokens.get(rid, [])) + ] + if not full_prefill_reqs: + return result + + global_donor = max( + full_prefill_reqs, key=lambda rid: len(prompt_tokens.get(rid, [])) + ) + gd_prefill = eff_prefill[global_donor] + gd_tokens = prompt_tokens.get(global_donor, []) + gd_arr = result[global_donor] + + for req_id in all_req_ids: + if req_id == global_donor: + continue + req_toks = prompt_tokens.get(req_id, []) + req_prompt_len = len(req_toks) + cur_prefill = eff_prefill[req_id] + if cur_prefill >= req_prompt_len: + continue + + # Longest common prefix with global donor + lcp = 0 + for a, b in zip(gd_tokens, req_toks): + if a == b: + lcp += 1 + else: + break + + missing = req_prompt_len - cur_prefill + fillable = min(missing, lcp, gd_prefill) + if fillable > 0: + result[req_id] = np.concatenate([gd_arr[:fillable], result[req_id]], axis=0) + eff_prefill[req_id] = cur_prefill + fillable + + return result + + def _safe_float_uncertainty(u: Any) -> float: """Robustly extract a scalar from estimator output.""" if isinstance(u, (float, int)): @@ -575,10 +674,15 @@ def _raw_generate( if self.prompt_logprobs: sampling_params.prompt_logprobs = sampling_params.logprobs - # ---- 1) MAIN GENERATION (without HS capture) ---- + # ---- Setup HS capture BEFORE generation if using native path ---- + if self.output_hidden_states and self.use_native_hs_capture: + self._ensure_hs_extension() + self._engine_core.collective_rpc("_reset_capture") + + # ---- 1) MAIN GENERATION (with HS capture if native path) ---- outputs = self.llm.generate(prompts_list, sampling_params) - # ---- 2) Setup HS capture for additional forward pass ---- + # ---- 2) Extract hidden states ---- hs_by_req_out: List[List[Optional[List[torch.Tensor]]]] = [ [None] * len(ro.outputs) for ro in outputs ] @@ -593,88 +697,137 @@ def _raw_generate( ) if self.use_native_hs_capture: - # Path 2: Native capture - sequential additional forward pass - # Process each sequence separately to avoid prefix sharing issues - self._ensure_hs_extension() - - # Create sampling params for 1 token generation - one_token_params = SamplingParams( - temperature=0, - max_tokens=1, - ) - - log.info( - f"Running sequential additional forward pass with HS capture for {len(flat_full_ids)} sequences" - ) - - # Process each sequence separately (batch_size=1) - for k, (i, j) in enumerate(flat_index): - seq_ids = flat_full_ids[k] - seq_len = len(seq_ids) - - log.info( - f"Processing sequence {k+1}/{len(flat_full_ids)} (req_idx={i}, out_idx={j}, seq_len={seq_len})" + # Path 2: Native capture — HS were captured during main generation. + # Extract per-request hidden states and fill prefix-cache gaps. + + per_rank = self._engine_core.collective_rpc("_get_captured_states") + captured_raw: dict = per_rank[0] if per_rank else {} + meta_raw = self._engine_core.collective_rpc("_get_capture_metadata") + meta: dict = meta_raw[0] if meta_raw else {} + + # Map output request_id to prompt index in prompts_list. + # vLLM request ID formats: + # "6" — simple numeric + # "6-abc123" — numeric prefix + hash suffix + # "2_0" — format {seq_idx}_{prompt_idx} for best-of-n + output_short_ids: Dict[str, int] = {} + for i, out in enumerate(outputs): + output_short_ids[out.request_id] = i + # Index by numeric prefix: "6-ab" -> "6" + if "-" in out.request_id: + prefix = out.request_id.split("-")[0] + output_short_ids.setdefault(prefix, i) + # Index by suffix for {seq}_{prompt} format: "0_1" -> "1" + if "_" in out.request_id: + suffix = out.request_id.rsplit("_", 1)[-1] + output_short_ids.setdefault(suffix, i) + + def _resolve_prompt_idx(req_id: str) -> Optional[int]: + # Exact match + if req_id in output_short_ids: + return output_short_ids[req_id] + # Format {seq}_{prompt}: suffix is prompt index + if "_" in req_id: + suffix = req_id.rsplit("_", 1)[-1] + if suffix in output_short_ids: + return output_short_ids[suffix] + # Format {prompt}-{hash}: prefix is prompt index + if "-" in req_id: + prefix = req_id.split("-")[0] + if prefix in output_short_ids: + return output_short_ids[prefix] + return None + + def _resolve_seq_idx(req_id: str) -> Optional[int]: + """Extract sequence index: '2_0' -> seq=2 (prefix).""" + if "_" in req_id: + parts = req_id.rsplit("_", 1) + try: + return int(parts[0]) + except ValueError: + pass + return None + + # Process each captured layer + for lid in self.hs_layer_ids: + layer_data = captured_raw.get(lid, {}) + if not layer_data: + continue + + # Deserialize per-request arrays + captured_arrays: Dict[str, np.ndarray] = {} + for req_id, pickled_bytes in layer_data.items(): + captured_arrays[req_id] = pickle.loads(pickled_bytes) + + # Build prompt groups and prompt_tokens for gap filling + prompt_groups: Dict[str, List[str]] = {} + req_prompt_tokens: Dict[str, List[int]] = {} + for req_id in captured_arrays: + idx = _resolve_prompt_idx(req_id) + if idx is None: + continue + text = prompts_list[idx] + prompt_groups.setdefault(text, []).append(req_id) + req_prompt_tokens[req_id] = prompt_token_ids[idx] + + # Fill prefix-cache gaps + filled = _fill_prefix_gaps( + captured_arrays, + meta, + prompt_groups, + prompt_tokens=req_prompt_tokens, ) - # Reset capture buffer and prefix cache for each sequence - self._engine_core.collective_rpc("_reset_capture") - self._reset_hs_prefix_cache() - - # Generate 1 token for THIS SEQUENCE ONLY (batch_size=1) - _ = self.llm.generate( - [TokensPrompt(prompt_token_ids=seq_ids)], - sampling_params=one_token_params, - ) - - # Extract captured states for this sequence - per_rank = self._engine_core.collective_rpc("_get_captured_states") - captured_raw: dict = per_rank[0] if per_rank else {} - - # Deserialize pickled arrays - captured = {} - for lid, pickled_bytes in captured_raw.items(): - arr = pickle.loads(pickled_bytes) - captured[lid] = torch.from_numpy(arr) - log.info( - f" Sequence {k}: Layer {lid}: captured shape={arr.shape}, dtype={arr.dtype}" - ) - - # Verify we captured exactly seq_len tokens - total_captured = 0 - if captured: - first_lid = list(captured.keys())[0] - total_captured = captured[first_lid].shape[0] - log.info( - f" Sequence {k}: captured {total_captured} tokens, expected {seq_len}, match={total_captured == seq_len}" - ) - - # Form hidden states list for this sequence - layer_states = [] - for lid in self.hs_layer_ids: - if lid in captured: - tensor = captured[lid] - # We expect exactly seq_len tokens for this single sequence - if tensor.shape[0] == seq_len: - layer_states.append(tensor) - else: - log.warning( - f" Sequence {k}: Layer {lid} has {tensor.shape[0]} tokens, expected {seq_len}" + # Assign filled HS to the right (request, output) slot. + # With best-of-n, vLLM creates sub-requests like "4_0", + # "4_1" — each with its own hidden states. Map the suffix + # to the correct output sequence index j. + for req_id, arr in filled.items(): + idx = _resolve_prompt_idx(req_id) + if idx is None: + continue + + seq_j = _resolve_seq_idx(req_id) + num_outputs = len(outputs[idx].outputs) + + if seq_j is not None and seq_j < num_outputs: + # Specific sequence index from req_id (e.g. "4_0") + target_js = [seq_j] + else: + # No sequence suffix or n=1 — assign to all outputs + target_js = list(range(num_outputs)) + + for j in target_js: + out = outputs[idx].outputs[j] + gen_len = len(getattr(out, "token_ids", []) or []) + expected = len(prompt_token_ids[idx]) + gen_len + + tensor = torch.from_numpy(arr) + + # Trim if captured more than expected + if tensor.shape[0] > expected: + tensor = tensor[:expected] + + # Pad with zero vectors if short + # (last generated token may not get a forward pass) + if tensor.shape[0] < expected: + pad_count = expected - tensor.shape[0] + pad = torch.zeros( + pad_count, tensor.shape[1], dtype=tensor.dtype ) - # Slice to expected length if needed - layer_states.append(tensor[:seq_len]) + tensor = torch.cat([tensor, pad], dim=0) - hs_by_req_out[i][j] = ( - {"hidden_states": layer_states} if layer_states else None - ) + if hs_by_req_out[idx][j] is None: + hs_by_req_out[idx][j] = {"hidden_states": []} + hs_by_req_out[idx][j]["hidden_states"].append(tensor) - # Cleanup: remove hooks and reset capture buffer after generation + # Cleanup hooks self._engine_core.collective_rpc( "_setup_hidden_states_capture", args=([],) ) self._engine_core.collective_rpc("_reset_capture") - # Reset flag so hooks are re-registered on next run self._hs_extension_ready = False - log.info("Removed HS capture hooks and reset buffer after generation") + log.info("Native HS capture: extracted per-request hidden states") else: # Path 1: VllmHiddenStatesGenerator - separate generation # Sleep main LLM to free GPU memory before HS generator loads