Skip to content
Open
Show file tree
Hide file tree
Changes from 38 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
5b8d9b7
feat: add extra param sync groups for checkpoint engine
sl-1314 Mar 12, 2026
f9f0133
feature: support old_log_prob_server
sl-1314 Mar 20, 2026
7829343
fix: unify configs
sl-1314 Mar 31, 2026
2d48570
fix: add wait for
sl-1314 Mar 31, 2026
734cd3e
fix: restore debug code
sl-1314 Mar 31, 2026
70bccbd
fix: refactor OldLogProbServer
sl-1314 Apr 1, 2026
bb9e96b
fix: run e2e exp
sl-1314 Apr 7, 2026
9c7f927
Revert "feat: add extra param sync groups for checkpoint engine"
sl-1314 Apr 8, 2026
a2c8d3f
fix: misc
sl-1314 Apr 8, 2026
2f4b7fa
Merge branch 'main' into standalone_old_log_prob_support
sl-1314 Apr 8, 2026
15b0eb6
refactor: refactor old_log_prob_server to rollout replica
sl-1314 Apr 9, 2026
adf2e16
fix: revert unnecessary change
sl-1314 Apr 9, 2026
8a01834
refactor: move old_log_prob_server to fully_async/
sl-1314 Apr 9, 2026
944c9f5
fix: restore unnecessary code
sl-1314 Apr 9, 2026
c41f236
fix: restore unnecessary code
sl-1314 Apr 9, 2026
5707a9c
simplify old_log_prob arch
sl-1314 Apr 9, 2026
59d3cd7
update: notes
sl-1314 Apr 9, 2026
9c97293
fix: misc
sl-1314 Apr 9, 2026
024b053
feat: support per tensor load
sl-1314 Apr 9, 2026
0ff02e4
update: misc
sl-1314 Apr 10, 2026
ea93e97
fix: remove redundant param
sl-1314 Apr 10, 2026
060b0d2
update: rename
sl-1314 Apr 13, 2026
3a60c6c
mv
sl-1314 Apr 13, 2026
3666fd1
fix: rename misc
sl-1314 Apr 13, 2026
2a0cb1f
fix: clean unused code
sl-1314 Apr 13, 2026
2adb3cb
fix: pre commit
sl-1314 Apr 13, 2026
08b92ce
fix: mv config
sl-1314 Apr 13, 2026
37ac30c
fix: misc
sl-1314 Apr 13, 2026
cedf3e4
update: remove unused code
sl-1314 Apr 13, 2026
a8e27ba
update: remove unused code
sl-1314 Apr 13, 2026
7d75b0a
Update verl/experimental/agent_loop/tool_agent_loop.py
sl-1314 Apr 13, 2026
0ff7de2
Merge branch 'standalone_old_log_prob_support' of https://github.com/…
sl-1314 Apr 13, 2026
999fb68
update: revise review comments
sl-1314 Apr 13, 2026
4b04e7b
update: rename variables
sl-1314 Apr 13, 2026
752a831
update: remove mbridge code
sl-1314 Apr 13, 2026
3be4d98
update: rename
sl-1314 Apr 14, 2026
98f7c4a
update: mv to extra_fields
sl-1314 Apr 14, 2026
354b026
update: remove other fix
sl-1314 Apr 14, 2026
c920d62
update: simplify validate code
sl-1314 Apr 14, 2026
00252c5
update: remove shutdown
sl-1314 Apr 14, 2026
3559cd8
update: misc
sl-1314 Apr 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/special_sanity/check_device_api_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
"verl/checkpoint_engine", # checkpoint engine backend are device specific
"verl/utils/modelopt/megatron_qat_patch.py", # appear in torch.cuda.empty_cache()
"verl/models/mcore/patch.py", # checkpoint patch only on cuda
"verl/verl/workers/rollout/model_engine_server/model_engine_server.py", # appear in default device_name
]

# directory or file path must contain keyword "nccl"
Expand Down
18 changes: 18 additions & 0 deletions verl/experimental/agent_loop/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,7 @@ async def generate_sequences(self, batch: DataProto) -> DataProto:
sampling_params["top_p"] = config.val_kwargs.top_p
sampling_params["top_k"] = config.val_kwargs.top_k
sampling_params["temperature"] = config.val_kwargs.temperature
sampling_params["validate"] = True

# by default, we assume it's a single turn agent
if "agent_name" not in batch.non_tensor_batch:
Expand Down Expand Up @@ -675,6 +676,14 @@ async def _agent_loop_postprocess(self, output, validate, **kwargs) -> _Internal
if output.response_logprobs is not None:
pad_size = self.rollout_config.response_length - len(output.response_logprobs)
response_logprobs = torch.tensor(output.response_logprobs + [0.0] * pad_size).unsqueeze(0)
_es_logprobs = output.extra_fields.pop("engine_server_logprobs", None)
if _es_logprobs is not None:
pad_size = self.rollout_config.response_length - len(_es_logprobs)
output.extra_fields["engine_server_logprobs"] = torch.tensor(_es_logprobs + [0.0] * pad_size).unsqueeze(0)
_es_entropys = output.extra_fields.pop("engine_server_entropys", None)
if _es_entropys is not None:
pad_size = self.rollout_config.response_length - len(_es_entropys)
output.extra_fields["engine_server_entropys"] = torch.tensor(_es_entropys + [0.0] * pad_size).unsqueeze(0)

response_mask = response_mask_output["input_ids"] * response_output["attention_mask"]
attention_mask = torch.cat([prompt_output["attention_mask"], response_output["attention_mask"]], dim=1)
Expand Down Expand Up @@ -884,6 +893,14 @@ def _postprocess(
optional_outputs = {}
if inputs[0].response_logprobs is not None:
optional_outputs["rollout_log_probs"] = torch.cat([input.response_logprobs for input in inputs], dim=0)
if inputs[0].extra_fields.get("engine_server_logprobs") is not None:
optional_outputs["engine_server_logprobs"] = torch.cat(
[input.extra_fields.pop("engine_server_logprobs") for input in inputs], dim=0
)
if inputs[0].extra_fields.get("engine_server_entropys") is not None:
optional_outputs["engine_server_entropys"] = torch.cat(
[input.extra_fields.pop("engine_server_entropys") for input in inputs], dim=0
)
if inputs[0].routed_experts is not None:
optional_outputs["routed_experts"] = torch.cat([input.routed_experts for input in inputs], dim=0)
if inputs[0].teacher_logprobs is not None and inputs[0].teacher_ids is not None:
Expand Down Expand Up @@ -1132,6 +1149,7 @@ async def _init_agent_loop_workers(self):
teacher_servers,
teacher_load_balancer_handle,
self.reward_loop_worker_handles,
self.model_engine_server_handle,
)
)

Expand Down
8 changes: 7 additions & 1 deletion verl/experimental/agent_loop/single_turn_agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu
metrics["num_preempted"] = output.num_preempted if output.num_preempted is not None else -1
response_mask = [1] * len(output.token_ids)

extra_fields = output.extra_fields
if extra_fields.get("engine_server_logprobs"):
extra_fields["engine_server_logprobs"] = extra_fields["engine_server_logprobs"][: self.response_length]
if extra_fields.get("engine_server_entropys"):
extra_fields["engine_server_entropys"] = extra_fields["engine_server_entropys"][: self.response_length]

output: AgentLoopOutput = AgentLoopOutput(
prompt_ids=prompt_ids,
response_ids=output.token_ids[: self.response_length],
Expand All @@ -81,7 +87,7 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu
multi_modal_data=multi_modal_data,
num_turns=2,
metrics=metrics,
extra_fields=output.extra_fields,
extra_fields=extra_fields,
)

# keeping the schema consistent with tool_agent_loop
Expand Down
26 changes: 23 additions & 3 deletions verl/experimental/agent_loop/tool_agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,12 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu
if agent_data.video_data is not None:
multi_modal_data["videos"] = agent_data.video_data

extra_fields = agent_data.extra_fields
if extra_fields.get("engine_server_logprobs"):
extra_fields["engine_server_logprobs"] = extra_fields["engine_server_logprobs"][: self.response_length]
if extra_fields.get("engine_server_entropys"):
extra_fields["engine_server_entropys"] = extra_fields["engine_server_entropys"][: self.response_length]

output: AgentLoopOutput = AgentLoopOutput(
prompt_ids=prompt_ids,
response_ids=response_ids[: self.response_length],
Expand All @@ -195,7 +201,7 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu
num_turns=agent_data.user_turns + agent_data.assistant_turns + 1,
metrics=agent_data.metrics,
routed_experts=agent_data.routed_experts,
extra_fields=agent_data.extra_fields,
extra_fields=extra_fields,
)
output.extra_fields.update({"turn_scores": agent_data.turn_scores, "tool_rewards": agent_data.tool_rewards})
return output
Expand Down Expand Up @@ -232,8 +238,9 @@ async def _handle_generating_state(
else:
agent_data.metrics["num_preempted"] += output.num_preempted if output.num_preempted is not None else 0

_ACCUMULATED_KEYS = {"engine_server_logprobs", "engine_server_entropys"}
if not agent_data.extra_fields:
agent_data.extra_fields.update(output.extra_fields)
agent_data.extra_fields.update({k: v for k, v in output.extra_fields.items() if k not in _ACCUMULATED_KEYS})
else:
# Multi-round calls, only update the maximum max_global_steps.
max_global_steps = output.extra_fields.get("max_global_steps", None)
Expand All @@ -246,7 +253,12 @@ async def _handle_generating_state(
agent_data.response_mask += [1] * len(agent_data.response_ids)
if output.log_probs:
agent_data.response_logprobs += output.log_probs

if output.extra_fields.get("engine_server_logprobs"):
agent_data.extra_fields.setdefault("engine_server_logprobs", [])
agent_data.extra_fields["engine_server_logprobs"] += output.extra_fields["engine_server_logprobs"]
if output.extra_fields.get("engine_server_entropys"):
agent_data.extra_fields.setdefault("engine_server_entropys", [])
agent_data.extra_fields["engine_server_entropys"] += output.extra_fields["engine_server_entropys"]
if output.routed_experts is not None:
agent_data.routed_experts = output.routed_experts

Expand Down Expand Up @@ -378,6 +390,10 @@ async def _handle_processing_tools_state(self, agent_data: AgentData) -> AgentSt
agent_data.response_mask += [0] * len(response_ids)
if agent_data.response_logprobs:
agent_data.response_logprobs += [0.0] * len(response_ids)
if agent_data.extra_fields.get("engine_server_logprobs"):
agent_data.extra_fields["engine_server_logprobs"] += [0.0] * len(response_ids)
if agent_data.extra_fields.get("engine_server_entropys"):
agent_data.extra_fields["engine_server_entropys"] += [0.0] * len(response_ids)
agent_data.user_turns += 1
return AgentState.GENERATING

Expand Down Expand Up @@ -410,6 +426,10 @@ async def _handle_interacting_state(self, agent_data: AgentData) -> AgentState:
agent_data.response_mask += [0] * len(response_ids)
if agent_data.response_logprobs:
agent_data.response_logprobs += [0.0] * len(response_ids)
if agent_data.extra_fields.get("engine_server_logprobs"):
agent_data.extra_fields["engine_server_logprobs"] += [0.0] * len(response_ids)
if agent_data.extra_fields.get("engine_server_entropys"):
agent_data.extra_fields["engine_server_entropys"] += [0.0] * len(response_ids)

# double check prompt
# Check termination condition
Expand Down
90 changes: 88 additions & 2 deletions verl/experimental/fully_async_policy/agent_loop/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ class FullyAsyncLLMServerManager(AsyncLLMServerManager):
invisible to the AgentLoop.
"""

def __init__(
    self,
    config: DictConfig,
    servers: list[tuple[str, ray.actor.ActorHandle]],
    load_balancer_handle: ray.actor.ActorHandle,
    model_engine_server_handle: ray.actor.ActorHandle = None,
):
    """Build the fully-async server manager, optionally wiring a model engine server.

    Args:
        config: Full job configuration passed through to the base manager.
        servers: ``(server_name, actor_handle)`` pairs of the rollout LLM servers.
        load_balancer_handle: Actor handle of the request load balancer.
        model_engine_server_handle: Optional actor handle of the standalone model
            engine server used to recompute old log probs; ``None`` disables that
            code path.
    """
    # Record the (possibly None) engine-server handle, then delegate the rest
    # of the setup to the base AsyncLLMServerManager.
    self.model_engine_server_handle = model_engine_server_handle
    super().__init__(config=config, servers=servers, load_balancer_handle=load_balancer_handle)

@rollout_trace_op
async def generate(
self,
Expand Down Expand Up @@ -74,6 +84,7 @@ async def generate(
elif "max_new_tokens" in sampling_params:
limit_key = "max_new_tokens"
original_max_tokens = sampling_params.get(limit_key) if limit_key else None
validate = sampling_params.pop("validate", None)

final_output = TokenOutput(
token_ids=[],
Expand All @@ -91,11 +102,25 @@ async def generate(
image_data=image_data,
video_data=video_data,
)

current_prompt_ids = prompt_ids + final_output.token_ids
current_temperature = sampling_params.get("temperature", 1.0)
# Skip old log prob computation during validation
if not validate:
output = await self._compute_old_log_prob(output, current_prompt_ids, current_temperature)
# 2. merge output into final_output
final_output.token_ids.extend(output.token_ids)
if output.log_probs is not None:
final_output.log_probs.extend(output.log_probs)
if output.extra_fields.get("engine_server_logprobs") is not None:
final_output.extra_fields.setdefault("engine_server_logprobs", [])
final_output.extra_fields["engine_server_logprobs"].extend(
output.extra_fields["engine_server_logprobs"]
)
if output.extra_fields.get("engine_server_entropys") is not None:
final_output.extra_fields.setdefault("engine_server_entropys", [])
final_output.extra_fields["engine_server_entropys"].extend(
output.extra_fields["engine_server_entropys"]
)
if output.routed_experts is not None:
if final_output.routed_experts is None:
final_output.routed_experts = output.routed_experts
Expand All @@ -121,11 +146,30 @@ async def generate(
# 4. check stop reason
if output.stop_reason not in ("aborted", "abort") or not self.config.async_training.partial_rollout:
break
if validate: # restore for multi-turn use
sampling_params["validate"] = validate
final_output.extra_fields["global_steps"] = global_steps
final_output.extra_fields["min_global_steps"] = min_global_steps
final_output.extra_fields["max_global_steps"] = max_global_steps
return final_output

async def _compute_old_log_prob(self, output: TokenOutput, context_prompt_ids, temperature: float):
    """Attach engine-server log probs / entropies for this turn's new tokens.

    Args:
        output: Token output of the current generation turn; its
            ``extra_fields`` dict is populated in place.
        context_prompt_ids: Full token context (prompt plus previously
            generated tokens) preceding ``output.token_ids``.
        temperature: Sampling temperature forwarded to the engine server.

    Returns:
        The same ``output`` object, with ``engine_server_logprobs`` and
        ``engine_server_entropys`` set when a server handle is configured.
    """
    handle = self.model_engine_server_handle
    if handle is None:
        # No standalone engine server configured: leave output untouched.
        return output

    extra = output.extra_fields
    if not output.token_ids:
        # Nothing was generated this turn; store empty per-token stats so
        # downstream accumulation still finds the keys.
        extra["engine_server_logprobs"] = []
        extra["engine_server_entropys"] = []
        return output

    # Recompute old_log_probs only for the tokens produced this turn.
    stats = await handle.compute_log_prob.remote(context_prompt_ids, output.token_ids, temperature)
    extra["engine_server_logprobs"] = stats["log_probs"]
    extra["engine_server_entropys"] = stats["entropy"]
    return output


@ray.remote
class FullyAsyncAgentLoopWorker(AgentLoopWorker):
Expand All @@ -137,8 +181,14 @@ def __init__(
teacher_servers: list[tuple[str, ray.actor.ActorHandle]] = None,
teacher_load_balancer_handle: ray.actor.ActorHandle = None,
reward_loop_worker_handles: list[ray.actor.ActorHandle] = None,
model_engine_server_handle: ray.actor.ActorHandle = None,
):
self.server_manager = FullyAsyncLLMServerManager(config, servers, load_balancer_handle)
self.server_manager = FullyAsyncLLMServerManager(
config,
servers,
load_balancer_handle,
model_engine_server_handle,
)
super().__init__(
config,
servers,
Expand All @@ -163,6 +213,42 @@ def __init__(
if self.distillation_enabled:
raise NotImplementedError("Distillation is not implemented in FullyAsyncAgentLoopManager yet.")

async def _initialize_llm_servers(self):
    """Extend base class to also create ModelEngineReplica when configured."""
    await super()._initialize_llm_servers()
    # Only stand up the extra model-engine replica when explicitly enabled.
    if not self.config.model_engine_server.enable_standalone:
        return
    await self._init_model_engine_replica()

async def _init_model_engine_replica(self):
    """Create ModelEngineReplica, call init_standalone, and append to rollout_replicas.

    ModelEngineReplica.init_standalone() self-allocates a Ray resource pool,
    spawns ModelEngineWorker actors, calls init_model(), and creates the
    OldLogProbServer — all in one call, exactly like vLLM/SGLang replicas.

    After this, self.model_engine_server_handle is set so that
    _init_agent_loop_workers() passes it to every FullyAsyncAgentLoopWorker.
    """
    from verl.workers.rollout.model_engine_server import ModelEngineReplica, ModelEngineWorker

    # NOTE(review): only a single replica is created here; it is unclear
    # whether this should support multiple data-parallel replicas — confirm
    # with the design before scaling up.
    replica = ModelEngineReplica(
        replica_rank=len(self.rollout_replicas),
        full_config=self.config,
        worker_cls=ModelEngineWorker,
        rollout_config=self.rollout_config,
    )
    await replica.init_standalone()
    self.rollout_replicas.append(replica)

    # Expose the server handle so _init_agent_loop_workers passes it to workers.
    self.model_engine_server_handle = replica.servers[0]

async def shutdown(self):
    """Shut down every replica that exposes a ``shutdown()`` coroutine.

    In practice this tears down the OldLogProbServer replica when one was
    created; plain rollout replicas without a ``shutdown`` attribute are
    skipped.
    """
    stoppable = [r for r in self.rollout_replicas if hasattr(r, "shutdown")]
    for r in stoppable:
        await r.shutdown()

@auto_await
async def generate_sequences_single(self, prompts: DataProto) -> DataProto:
"""Split input batch and dispatch to agent loop workers.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ hydra:

defaults:
- ppo_megatron_trainer

- model_engine_server@model_engine_server: megatron_model_engine_server

- _self_

trainer:
Expand Down
4 changes: 4 additions & 0 deletions verl/experimental/fully_async_policy/fully_async_rollouter.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,10 @@ async def fit(self):
# Wait for the task to complete
await asyncio.gather(generation_task, monitor_task, return_exceptions=True)

# Shut down OldLogProbServer (if any) — it lives on the Rollouter side.
if self.async_rollout_manager is not None:
await self.async_rollout_manager.shutdown()

print("[FullyAsyncRollouter] Rollouter fit completed")

async def _async_monitor_loop(self):
Expand Down
8 changes: 8 additions & 0 deletions verl/experimental/fully_async_policy/fully_async_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,14 @@ def _compute_old_log_prob(self, batch: DataProto):
If local_trigger_step == 2, 3, ..., restore the parameters of version 1 to calculate the old_log_prob,
then restore the parameters of the current version.
"""
if "engine_server_logprobs" in batch.batch:
batch_dict = {
"old_log_probs": batch.batch.pop("engine_server_logprobs"),
"entropys": batch.batch.pop("engine_server_entropys"),
}
old_log_prob = DataProto.from_dict(batch_dict)
return old_log_prob, 0.0

if self.local_trigger_step == 1:
self.actor_rollout_wg.save_model_to_cpu(1)
old_log_prob, old_log_prob_mfu = super()._compute_old_log_prob(batch)
Expand Down
Loading