[trainer,algo] feat: Support On-Policy Distillation in main_ppo_sync#5997
[trainer,algo] feat: Support On-Policy Distillation in main_ppo_sync#59970oshowero0 wants to merge 1 commit intoverl-project:mainfrom
main_ppo_sync#5997Conversation
Signed-off-by: 0oshowero0 <[email protected]> update Signed-off-by: 0oshowero0 <[email protected]> update Signed-off-by: 0oshowero0 <[email protected]> fix Signed-off-by: 0oshowero0 <[email protected]>
There was a problem hiding this comment.
Code Review
This pull request implements teacher model distillation support in the PPO synchronization loop, including the addition of a TeacherModelManager and updates to the agent loop and transfer queue utilities to handle teacher logprobs and KVBatchMeta. Review feedback highlights a potential KeyError in agent_loop.py when extra_fields is missing and a crash in main_ppo_sync.py when processing multiple agent loop outputs.
| teacher_ids, teacher_logprobs = ( | ||
| output["extra_fields"].pop("teacher_ids", None), | ||
| output["extra_fields"].pop("teacher_logprobs", None), | ||
| ) |
There was a problem hiding this comment.
The code assumes extra_fields is always present in the output dictionary. However, since output is created using self.model_dump(exclude_unset=True) at line 212, if extra_fields was not explicitly set (it has a default empty dict), it will be excluded from the resulting dictionary. This will cause a KeyError when attempting to access output["extra_fields"]. It is safer to use .get() or ensure the key exists before popping.
extra_fields = output.get("extra_fields", {})
teacher_ids, teacher_logprobs = (
extra_fields.pop("teacher_ids", None),
extra_fields.pop("teacher_logprobs", None),
)| await self._compute_teacher_logprobs( | ||
| output, | ||
| prompt_ids=output.prompt_ids, | ||
| response_ids=output.response_ids, | ||
| validate=validate, | ||
| ) |
There was a problem hiding this comment.
The code at lines 377-382 will crash if output is a list of AgentLoopOutput objects, as it attempts to access .prompt_ids and .response_ids directly on the output variable. While there is a TODO comment at line 376, the current implementation is broken for the multi-output case which is explicitly handled earlier in the function (line 350). You should iterate over the outputs list to compute teacher logprobs for each item.
| await self._compute_teacher_logprobs( | |
| output, | |
| prompt_ids=output.prompt_ids, | |
| response_ids=output.response_ids, | |
| validate=validate, | |
| ) | |
| # TODO: Support output:list[AgentLoopOutput] | |
| for out in outputs: | |
| await self._compute_teacher_logprobs( | |
| out, | |
| prompt_ids=out.prompt_ids, | |
| response_ids=out.response_ids, | |
| validate=validate, | |
| ) |
What does this PR do?
Support #5041 in
main_ppo_sync.pyChecklist Before Starting
[{modules}] {type}: {description}(This will be checked by the CI){modules}includefsdp,megatron,veomni,sglang,vllm,rollout,trainer,ci,training_utils,recipe,hardware,deployment,ray,worker,single_controller,misc,perf,model,algo,env,tool,ckpt,doc,data,cfg,reward,fully_async,one_step_off,like[megatron, fsdp, doc]{type}is infeat,fix,refactor,chore,test[BREAKING]to the beginning of the title.[BREAKING][fsdp, megatron] feat: dynamic batchingTest
API and Usage Example
# Add code snippet or script demonstrating how to use thisDesign & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=alwaysci-requestchannel in theverlSlack workspace. (If not accessible, please try the Feishu group (飞书群).)recipesubmodule, please also update the reference to the submodule commit viagit submodule update --remoteorcd recipe && git pull origin main.