Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions miles/rollout/generate_utils/openai_endpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ def _compute_sample_from_openai_record(
case "abort":
sample.status = Sample.Status.ABORTED

sample.prefix_cache_info.add(choice.get("meta_info", {}))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

While this correctly populates the prefix cache information, it is inconsistent with the rest of the function which accesses meta_info directly (e.g., lines 159-160). If meta_info is missing, the function will have already raised a KeyError before reaching this line. For better consistency and safety, consider extracting meta_info once at the beginning of the function and using it throughout.


return sample


Expand Down
77 changes: 73 additions & 4 deletions tests/fast/rollout/generate_utils/test_openai_endpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def _make_record(
output_token_ids: list[int],
output_log_probs: list[float] | None = None,
finish_reason: str = "stop",
cached_tokens: int | None = None,
prompt_tokens: int | None = None,
) -> SessionRecord:
"""Build a minimal session record mimicking SGLang's response format.

Expand All @@ -59,6 +61,14 @@ def _make_record(
logprobs_content = [
{"logprob": lp, "token": f"t{tid}"} for tid, lp in zip(output_token_ids, output_log_probs, strict=True)
]
meta_info = {
"output_token_logprobs": output_token_logprobs,
"completion_tokens": len(output_token_ids),
}
if cached_tokens is not None:
meta_info["cached_tokens"] = cached_tokens
if prompt_tokens is not None:
meta_info["prompt_tokens"] = prompt_tokens
return SessionRecord(
timestamp=0.0,
method="POST",
Expand All @@ -72,10 +82,7 @@ def _make_record(
"message": {"role": "assistant", "content": "response"},
"finish_reason": finish_reason,
"logprobs": {"content": logprobs_content},
"meta_info": {
"output_token_logprobs": output_token_logprobs,
"completion_tokens": len(output_token_ids),
},
"meta_info": meta_info,
}
]
},
Expand Down Expand Up @@ -527,3 +534,65 @@ def test_no_thinking_tokens_prefix_chain_holds(self):
merged = merge_samples(samples, tok)

assert merged.tokens == [1, 2, 3, 10, 11, 20, 21, 30, 31]


# ── test: prefix cache info population ────────────────────────────────


class TestPrefixCacheInfo:
    """Validate that prefix cache statistics from meta_info are collected."""

    def test_single_record_with_cache_stats(self):
        """cached_tokens and prompt_tokens from meta_info populate prefix_cache_info."""
        tokenizer = _mock_tokenizer()
        rec = _make_record(
            prompt_token_ids=[1, 2, 3],
            output_token_ids=[10, 11],
            cached_tokens=2,
            prompt_tokens=3,
        )
        out = compute_samples_from_openai_records(_ARGS, _make_input_sample(), [rec], tokenizer)

        cache_info = out[0].prefix_cache_info
        assert cache_info.cached_tokens == 2
        assert cache_info.total_prompt_tokens == 3

    def test_multi_turn_cache_stats_accumulate_after_merge(self):
        """After merge_samples, prefix_cache_info sums across turns."""
        tokenizer = _mock_tokenizer()
        # Two turns: the first is a cold prompt, the second reuses a prefix.
        turn_one = _make_record(
            prompt_token_ids=[1, 2, 3],
            output_token_ids=[10, 11],
            output_log_probs=[-0.1, -0.2],
            cached_tokens=0,
            prompt_tokens=3,
        )
        turn_two = _make_record(
            prompt_token_ids=[1, 2, 3, 10, 11, 20, 21],
            output_token_ids=[30, 31],
            output_log_probs=[-0.3, -0.4],
            cached_tokens=5,
            prompt_tokens=7,
        )
        per_turn = compute_samples_from_openai_records(
            _ARGS, _make_input_sample(), [turn_one, turn_two], tokenizer
        )
        merged = merge_samples(per_turn, tokenizer)

        # Totals are the per-turn sums; hit rate is cached / total prompt tokens.
        assert merged.prefix_cache_info.cached_tokens == 0 + 5
        assert merged.prefix_cache_info.total_prompt_tokens == 3 + 7
        assert merged.prefix_cache_info.prefix_cache_hit_rate == 5 / 10

    def test_missing_cache_fields_default_to_zero(self):
        """Records without cached_tokens/prompt_tokens give zero prefix_cache_info (regression)."""
        tokenizer = _mock_tokenizer()
        rec = _make_record(
            prompt_token_ids=[1, 2, 3],
            output_token_ids=[10, 11],
        )
        out = compute_samples_from_openai_records(_ARGS, _make_input_sample(), [rec], tokenizer)

        cache_info = out[0].prefix_cache_info
        assert cache_info.cached_tokens == 0
        assert cache_info.total_prompt_tokens == 0
Loading