diff --git a/miles/rollout/generate_utils/openai_endpoint_utils.py b/miles/rollout/generate_utils/openai_endpoint_utils.py
index d054cf2c52..2b3016fd4f 100644
--- a/miles/rollout/generate_utils/openai_endpoint_utils.py
+++ b/miles/rollout/generate_utils/openai_endpoint_utils.py
@@ -199,6 +199,7 @@ def _compute_sample_from_openai_record(
         case "abort":
             sample.status = Sample.Status.ABORTED
 
+    sample.prefix_cache_info.add(choice.get("meta_info", {}))
     if "weight_version" in choice["meta_info"]:
         sample.weight_versions.append(choice["meta_info"]["weight_version"])
 
diff --git a/tests/fast/rollout/generate_utils/test_openai_endpoint_utils.py b/tests/fast/rollout/generate_utils/test_openai_endpoint_utils.py
index e8cb2eb340..6d90fdd38b 100644
--- a/tests/fast/rollout/generate_utils/test_openai_endpoint_utils.py
+++ b/tests/fast/rollout/generate_utils/test_openai_endpoint_utils.py
@@ -49,6 +49,8 @@ def _make_record(
     output_token_ids: list[int],
     output_log_probs: list[float] | None = None,
     finish_reason: str = "stop",
+    cached_tokens: int | None = None,
+    prompt_tokens: int | None = None,
 ) -> SessionRecord:
     """Build a minimal session record mimicking SGLang's response format.
 
@@ -62,6 +64,14 @@
     logprobs_content = [
         {"logprob": lp, "token": f"t{tid}"} for tid, lp in zip(output_token_ids, output_log_probs, strict=True)
     ]
+    meta_info = {
+        "output_token_logprobs": output_token_logprobs,
+        "completion_tokens": len(output_token_ids),
+    }
+    if cached_tokens is not None:
+        meta_info["cached_tokens"] = cached_tokens
+    if prompt_tokens is not None:
+        meta_info["prompt_tokens"] = prompt_tokens
     return SessionRecord(
         timestamp=0.0,
         method="POST",
@@ -75,10 +85,7 @@
                     "message": {"role": "assistant", "content": "response"},
                     "finish_reason": finish_reason,
                     "logprobs": {"content": logprobs_content},
-                    "meta_info": {
-                        "output_token_logprobs": output_token_logprobs,
-                        "completion_tokens": len(output_token_ids),
-                    },
+                    "meta_info": meta_info,
                 }
             ]
         },
@@ -557,3 +564,65 @@ def test_no_thinking_tokens_prefix_chain_holds(self):
         merged = merge_samples(samples, tok)
 
         assert merged.tokens == [1, 2, 3, 10, 11, 20, 21, 30, 31]
+
+
+# ── test: prefix cache info population ────────────────────────────────
+
+
+class TestPrefixCacheInfo:
+    """Validate that prefix cache statistics from meta_info are collected."""
+
+    def test_single_record_with_cache_stats(self):
+        """cached_tokens and prompt_tokens from meta_info populate prefix_cache_info."""
+        tok = _mock_tokenizer()
+        record = _make_record(
+            prompt_token_ids=[1, 2, 3],
+            output_token_ids=[10, 11],
+            cached_tokens=2,
+            prompt_tokens=3,
+        )
+        input_sample = _make_input_sample()
+        samples = compute_samples_from_openai_records(_ARGS, input_sample, [record], tok)
+
+        assert samples[0].prefix_cache_info.cached_tokens == 2
+        assert samples[0].prefix_cache_info.total_prompt_tokens == 3
+
+    def test_multi_turn_cache_stats_accumulate_after_merge(self):
+        """After merge_samples, prefix_cache_info sums across turns."""
+        tok = _mock_tokenizer()
+        records = [
+            _make_record(
+                prompt_token_ids=[1, 2, 3],
+                output_token_ids=[10, 11],
+                output_log_probs=[-0.1, -0.2],
+                cached_tokens=0,
+                prompt_tokens=3,
+            ),
+            _make_record(
+                prompt_token_ids=[1, 2, 3, 10, 11, 20, 21],
+                output_token_ids=[30, 31],
+                output_log_probs=[-0.3, -0.4],
+                cached_tokens=5,
+                prompt_tokens=7,
+            ),
+        ]
+        input_sample = _make_input_sample()
+        samples = compute_samples_from_openai_records(_ARGS, input_sample, records, tok)
+        merged = merge_samples(samples, tok)
+
+        assert merged.prefix_cache_info.cached_tokens == 0 + 5
+        assert merged.prefix_cache_info.total_prompt_tokens == 3 + 7
+        assert merged.prefix_cache_info.prefix_cache_hit_rate == 5 / 10
+
+    def test_missing_cache_fields_default_to_zero(self):
+        """Records without cached_tokens/prompt_tokens give zero prefix_cache_info (regression)."""
+        tok = _mock_tokenizer()
+        record = _make_record(
+            prompt_token_ids=[1, 2, 3],
+            output_token_ids=[10, 11],
+        )
+        input_sample = _make_input_sample()
+        samples = compute_samples_from_openai_records(_ARGS, input_sample, [record], tok)
+
+        assert samples[0].prefix_cache_info.cached_tokens == 0
+        assert samples[0].prefix_cache_info.total_prompt_tokens == 0
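
Note on the API exercised above: PrefixCacheInfo itself is not defined in this diff. The sketch below is a hypothetical reconstruction, assuming SGLang-style meta_info dicts with optional "cached_tokens" and "prompt_tokens" keys; it shows an accumulator consistent with the add() call in _compute_sample_from_openai_record and the fields and property the tests assert on. The real class in miles may differ.

    from dataclasses import dataclass


    @dataclass
    class PrefixCacheInfo:
        """Running totals of prefix-cache statistics, summed across turns."""

        cached_tokens: int = 0
        total_prompt_tokens: int = 0

        def add(self, meta_info: dict) -> None:
            # Missing keys default to 0, matching the regression test above.
            self.cached_tokens += meta_info.get("cached_tokens", 0)
            self.total_prompt_tokens += meta_info.get("prompt_tokens", 0)

        @property
        def prefix_cache_hit_rate(self) -> float:
            # Guard against division by zero before any prompt tokens are seen.
            if self.total_prompt_tokens == 0:
                return 0.0
            return self.cached_tokens / self.total_prompt_tokens

Under this sketch, the multi-turn expectations hold: two turns with (cached=0, prompt=3) and (cached=5, prompt=7) accumulate to 5 cached of 10 prompt tokens, i.e. a hit rate of 0.5.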