Skip to content
Open
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
5b8d9b7
feat: add extra param sync groups for checkpoint engine
sl-1314 Mar 12, 2026
f9f0133
feature: support old_log_prob_server
sl-1314 Mar 20, 2026
7829343
fix: unify configs
sl-1314 Mar 31, 2026
2d48570
fix: add wait for
sl-1314 Mar 31, 2026
734cd3e
fix: restore debug code
sl-1314 Mar 31, 2026
70bccbd
fix: refator OldLogProbServer
sl-1314 Apr 1, 2026
bb9e96b
fix: run e2e exp
sl-1314 Apr 7, 2026
9c7f927
Revert "feat: add extra param sync groups for checkpoint engine"
sl-1314 Apr 8, 2026
a2c8d3f
fix: misc
sl-1314 Apr 8, 2026
2f4b7fa
Merge branch 'main' into standalone_old_log_prob_support
sl-1314 Apr 8, 2026
15b0eb6
refactor: refactor old_log_prob_server to rollout replica
sl-1314 Apr 9, 2026
adf2e16
fix: revert unnessary change
sl-1314 Apr 9, 2026
8a01834
refacor: move old_log_prob_server to fully_async/
sl-1314 Apr 9, 2026
944c9f5
fix: resotre unnessary code
sl-1314 Apr 9, 2026
c41f236
fix: resotre unnessary code
sl-1314 Apr 9, 2026
5707a9c
simplify old_log_prob arch
sl-1314 Apr 9, 2026
59d3cd7
update: notes
sl-1314 Apr 9, 2026
9c97293
fix: misc
sl-1314 Apr 9, 2026
024b053
feat: support per tensor load
sl-1314 Apr 9, 2026
0ff02e4
update: misc
sl-1314 Apr 10, 2026
ea93e97
fix: remove redundant param
sl-1314 Apr 10, 2026
060b0d2
update: rename
sl-1314 Apr 13, 2026
3a60c6c
mv
sl-1314 Apr 13, 2026
3666fd1
fix: rename misc
sl-1314 Apr 13, 2026
2a0cb1f
fix: clean unused code
sl-1314 Apr 13, 2026
2adb3cb
fix: pre commit
sl-1314 Apr 13, 2026
08b92ce
fix: mv config
sl-1314 Apr 13, 2026
37ac30c
fix: misc
sl-1314 Apr 13, 2026
cedf3e4
update: remove unused code
sl-1314 Apr 13, 2026
a8e27ba
update: remove unused code
sl-1314 Apr 13, 2026
7d75b0a
Update verl/experimental/agent_loop/tool_agent_loop.py
sl-1314 Apr 13, 2026
0ff7de2
Merge branch 'standalone_old_log_prob_support' of https://github.com/…
sl-1314 Apr 13, 2026
999fb68
update: revise review comments
sl-1314 Apr 13, 2026
4b04e7b
update: rename variables
sl-1314 Apr 13, 2026
752a831
update: remove mbridge code
sl-1314 Apr 13, 2026
3be4d98
update: rename
sl-1314 Apr 14, 2026
98f7c4a
update: mv to extra_fields
sl-1314 Apr 14, 2026
354b026
update: remove other fix
sl-1314 Apr 14, 2026
c920d62
update: simplify validate code
sl-1314 Apr 14, 2026
00252c5
update: remove shutdown
sl-1314 Apr 14, 2026
3559cd8
update: misc
sl-1314 Apr 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/special_sanity/check_device_api_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
"verl/checkpoint_engine", # checkpoint engine backend are device specific
"verl/utils/modelopt/megatron_qat_patch.py", # appear in torch.cuda.empty_cache()
"verl/models/mcore/patch.py", # checkpoint patch only on cuda
"verl/workers/rollout/model_engine_server/model_engine_server.py",  # appears in default device_name
]

# directory or file path must contain keyword "nccl"
Expand Down
17 changes: 17 additions & 0 deletions verl/experimental/agent_loop/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def __init__(
"""
self.config = config
self._load_balancer = load_balancer_handle
self.model_engine_server_handle = None
self._server_id_to_handle: dict[str, ray.actor.ActorHandle] = dict(servers)

async def _acquire_server(self, request_id: str) -> tuple[str, ray.actor.ActorHandle]:
Expand Down Expand Up @@ -194,6 +195,8 @@ class AgentLoopOutput(BaseModel):
"""Response mask, 1 for LLM generated token, 0 for tool response token."""
response_logprobs: Optional[list[float]] = None
"""Log probabilities for the response tokens."""
response_oldlogprobs: Optional[list[float]] = None
"""Log probabilities calculated by standalone server for the response tokens."""
routed_experts: Optional[Any] = None
"""Routed experts for the total tokens."""
multi_modal_data: Optional[dict[str, Any]] = None
Expand Down Expand Up @@ -231,6 +234,8 @@ class _InternalAgentLoopOutput(AgentLoopOutput):
"""Padded log probabilities from teacher model for prompt/response tokens."""
teacher_ids: Optional[torch.Tensor] = None
"""Padded token ids corresponding to the teacher log probabilities."""
response_oldlogprobs: Optional[torch.Tensor] = None
"""Padded old log probabilities for the response tokens."""
routed_experts: Optional[torch.Tensor] = None
"""Padded routed experts for the total tokens."""
multi_modal_inputs: Optional[dict[str, torch.Tensor]] = None
Expand Down Expand Up @@ -531,6 +536,7 @@ async def generate_sequences(self, batch: DataProto) -> DataProto:
sampling_params["top_p"] = config.val_kwargs.top_p
sampling_params["top_k"] = config.val_kwargs.top_k
sampling_params["temperature"] = config.val_kwargs.temperature
sampling_params["validate"] = True

# by default, we assume it's a single turn agent
if "agent_name" not in batch.non_tensor_batch:
Expand Down Expand Up @@ -675,6 +681,10 @@ async def _agent_loop_postprocess(self, output, validate, **kwargs) -> _Internal
if output.response_logprobs is not None:
pad_size = self.rollout_config.response_length - len(output.response_logprobs)
response_logprobs = torch.tensor(output.response_logprobs + [0.0] * pad_size).unsqueeze(0)
response_oldlogprobs = None
if output.response_oldlogprobs is not None:
pad_size = self.rollout_config.response_length - len(output.response_oldlogprobs)
response_oldlogprobs = torch.tensor(output.response_oldlogprobs + [0.0] * pad_size).unsqueeze(0)

response_mask = response_mask_output["input_ids"] * response_output["attention_mask"]
attention_mask = torch.cat([prompt_output["attention_mask"], response_output["attention_mask"]], dim=1)
Expand Down Expand Up @@ -750,6 +760,7 @@ async def _agent_loop_postprocess(self, output, validate, **kwargs) -> _Internal
response_mask=response_mask,
attention_mask=attention_mask,
response_logprobs=response_logprobs,
response_oldlogprobs=response_oldlogprobs,
routed_experts=routed_experts,
multi_modal_inputs=multi_modal_inputs,
multi_modal_data=output.multi_modal_data,
Expand Down Expand Up @@ -884,6 +895,10 @@ def _postprocess(
optional_outputs = {}
if inputs[0].response_logprobs is not None:
optional_outputs["rollout_log_probs"] = torch.cat([input.response_logprobs for input in inputs], dim=0)
if inputs[0].response_oldlogprobs is not None:
optional_outputs["server_old_log_probs"] = torch.cat(
[input.response_oldlogprobs for input in inputs], dim=0
)
if inputs[0].routed_experts is not None:
optional_outputs["routed_experts"] = torch.cat([input.routed_experts for input in inputs], dim=0)
if inputs[0].teacher_logprobs is not None and inputs[0].teacher_ids is not None:
Expand Down Expand Up @@ -1016,6 +1031,7 @@ def __init__(
self.worker_group = worker_group
self.rollout_resource_pool = rollout_resource_pool
self.reward_loop_worker_handles = reward_loop_worker_handles
self.model_engine_server_handle = None

self.teacher_model_manager = teacher_model_manager
self.distillation_enabled = is_distillation_enabled(self.config.get("distillation", None))
Expand Down Expand Up @@ -1132,6 +1148,7 @@ async def _init_agent_loop_workers(self):
teacher_servers,
teacher_load_balancer_handle,
self.reward_loop_worker_handles,
self.model_engine_server_handle,
)
)

Expand Down
1 change: 1 addition & 0 deletions verl/experimental/agent_loop/single_turn_agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu
response_ids=output.token_ids[: self.response_length],
response_mask=response_mask[: self.response_length],
response_logprobs=output.log_probs[: self.response_length] if output.log_probs else None,
response_oldlogprobs=output.old_log_probs[: self.response_length] if output.old_log_probs else None,
routed_experts=(
output.routed_experts[: len(prompt_ids) + self.response_length]
if output.routed_experts is not None
Expand Down
11 changes: 10 additions & 1 deletion verl/experimental/agent_loop/tool_agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(
self.response_ids: list[int] = []
self.response_mask: list[int] = []
self.response_logprobs: list[float] = []
self.response_oldlogprobs: list[float] = []
self.turn_scores: list[float] = []
self.tool_rewards: list[float] = []
self.user_turns = 0
Expand Down Expand Up @@ -192,6 +193,9 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu
response_logprobs=agent_data.response_logprobs[: self.response_length]
if agent_data.response_logprobs
else None,
response_oldlogprobs=agent_data.response_oldlogprobs[: self.response_length]
if agent_data.response_oldlogprobs
else None,
num_turns=agent_data.user_turns + agent_data.assistant_turns + 1,
metrics=agent_data.metrics,
routed_experts=agent_data.routed_experts,
Expand Down Expand Up @@ -246,7 +250,8 @@ async def _handle_generating_state(
agent_data.response_mask += [1] * len(agent_data.response_ids)
if output.log_probs:
agent_data.response_logprobs += output.log_probs

if output.old_log_probs:
    agent_data.response_oldlogprobs += output.old_log_probs
if output.routed_experts is not None:
agent_data.routed_experts = output.routed_experts

Expand Down Expand Up @@ -378,6 +383,8 @@ async def _handle_processing_tools_state(self, agent_data: AgentData) -> AgentSt
agent_data.response_mask += [0] * len(response_ids)
if agent_data.response_logprobs:
agent_data.response_logprobs += [0.0] * len(response_ids)
if agent_data.response_oldlogprobs:
agent_data.response_oldlogprobs += [0.0] * len(response_ids)
agent_data.user_turns += 1
return AgentState.GENERATING

Expand Down Expand Up @@ -410,6 +417,8 @@ async def _handle_interacting_state(self, agent_data: AgentData) -> AgentState:
agent_data.response_mask += [0] * len(response_ids)
if agent_data.response_logprobs:
agent_data.response_logprobs += [0.0] * len(response_ids)
if agent_data.response_oldlogprobs:
agent_data.response_oldlogprobs += [0.0] * len(response_ids)

# double check prompt
# Check termination condition
Expand Down
174 changes: 171 additions & 3 deletions verl/experimental/fully_async_policy/agent_loop/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import ray
import torch
from omegaconf import DictConfig
from tensordict import TensorDict

from verl.experimental.agent_loop.agent_loop import (
AgentLoopManager,
Expand All @@ -29,6 +30,8 @@
from verl.experimental.teacher_loop import TeacherModelManager
from verl.protocol import DataProto
from verl.single_controller.ray import RayResourcePool, RayWorkerGroup
from verl.utils import tensordict_utils as tu
from verl.utils.model import compute_position_id_with_mask
from verl.utils.ray_utils import auto_await
from verl.utils.rollout_trace import (
rollout_trace_op,
Expand All @@ -44,6 +47,18 @@ class FullyAsyncLLMServerManager(AsyncLLMServerManager):
invisible to the AgentLoop.
"""

def __init__(
    self,
    config: DictConfig,
    servers: list[tuple[str, ray.actor.ActorHandle]],
    load_balancer_handle: ray.actor.ActorHandle,
    model_engine_server_handle: ray.actor.ActorHandle = None,
    tokenizer: Any = None,
):
    """Server manager that can additionally recompute old log probs.

    Extends the base manager with an optional handle to a standalone
    model-engine server and the tokenizer needed to pad requests for it.
    """
    super().__init__(config, servers, load_balancer_handle)
    self.tokenizer = tokenizer
    self.model_engine_server_handle = model_engine_server_handle

@rollout_trace_op
async def generate(
self,
Expand Down Expand Up @@ -74,10 +89,12 @@ async def generate(
elif "max_new_tokens" in sampling_params:
limit_key = "max_new_tokens"
original_max_tokens = sampling_params.get(limit_key) if limit_key else None
validate = sampling_params.pop("validate", None)

final_output = TokenOutput(
token_ids=[],
log_probs=[],
old_log_probs=[],
num_preempted=0,
)
min_global_steps, max_global_steps = None, None
Expand All @@ -91,11 +108,17 @@ async def generate(
image_data=image_data,
video_data=video_data,
)

current_prompt_ids = prompt_ids + final_output.token_ids
current_temperature = sampling_params.get("temperature", 1.0)
# Skip old log prob computation during validation
if not validate:
output = await self._compute_old_log_prob(output, current_prompt_ids, current_temperature)
# 2. merge output into final_output
final_output.token_ids.extend(output.token_ids)
if output.log_probs is not None:
final_output.log_probs.extend(output.log_probs)
if output.old_log_probs is not None:
final_output.old_log_probs.extend(output.old_log_probs)
if output.routed_experts is not None:
if final_output.routed_experts is None:
final_output.routed_experts = output.routed_experts
Expand All @@ -121,11 +144,108 @@ async def generate(
# 4. check stop reason
if output.stop_reason not in ("aborted", "abort") or not self.config.async_training.partial_rollout:
break
if validate: # restore for multi-turn use
sampling_params["validate"] = validate
final_output.extra_fields["global_steps"] = global_steps
final_output.extra_fields["min_global_steps"] = min_global_steps
final_output.extra_fields["max_global_steps"] = max_global_steps
return final_output

async def _compute_old_log_prob(self, output: TokenOutput, context_prompt_ids, temperature: float):
    """Recompute old log probs for this turn's new tokens via the standalone server.

    Args:
        output: Generation output for the current turn. ``output.old_log_probs``
            is populated in place with one log prob per token in
            ``output.token_ids``.
        context_prompt_ids: Full prompt context (original prompt plus previously
            generated tokens) that conditioned this turn's generation.
        temperature: Sampling temperature of the rollout; forwarded so the
            server's log-prob computation matches the sampling distribution.

    Returns:
        The same ``output`` object. Unchanged when no model engine server is
        configured; otherwise ``old_log_probs`` is filled in.

    Raises:
        ValueError: if the context prompt exceeds the reserved padded length.
        RuntimeError: if no tokenizer is available for padding.
    """
    # No-op when the standalone old-log-prob server is not configured.
    if self.model_engine_server_handle is None:
        return output
    # Convert TokenOutput -> fixed-shape TensorDict for OldLogProbServer.
    if self.config.get("actor_rollout_ref"):
        rollout_config = self.config.actor_rollout_ref.rollout
    else:
        rollout_config = self.config.rollout

    # Prompt grows during partial-rollout/multi-turn. Reserve prompt slots for
    # [original prompt + at most full response length] to keep shapes static.
    max_prompt_len = int(rollout_config.prompt_length) + int(rollout_config.response_length)
    max_response_len = int(rollout_config.response_length)

    # Only recompute old_log_probs for newly generated tokens in this turn.
    if len(output.token_ids) == 0:
        output.old_log_probs = []
        return output

    prompt_len = len(context_prompt_ids)
    response_len = len(output.token_ids)

    if prompt_len > max_prompt_len:
        raise ValueError(
            f"prompt length {prompt_len} exceeds padded prompt length {max_prompt_len} "
            "for old_log_prob recomputation"
        )
    if response_len > max_response_len:
        print(
            f"response length {response_len} exceeds padded response length {max_response_len} "
            "for old_log_prob recomputation"
        )
        # NOTE(review): this truncates token_ids in place but leaves
        # output.log_probs at its original length — confirm callers tolerate
        # the length mismatch or truncate log_probs here as well.
        output.token_ids = output.token_ids[:max_response_len]
        response_len = max_response_len

    tokenizer = self.tokenizer
    if tokenizer is None:
        raise RuntimeError("tokenizer is required for old_log_prob recomputation padding")

    # Left-pad the prompt and right-pad the response so every request has a
    # static shape. padding_side is restored in a finally block so a failing
    # pad() call cannot leak a mutated tokenizer state to other coroutines
    # sharing this tokenizer. (No await occurs while the state is mutated.)
    original_padding_side = tokenizer.padding_side
    try:
        tokenizer.padding_side = "left"
        prompt_output = tokenizer.pad(
            {"input_ids": context_prompt_ids},
            padding="max_length",
            max_length=max_prompt_len,
            return_tensors="pt",
            return_attention_mask=True,
        )
        if prompt_output["input_ids"].dim() == 1:
            prompt_output["input_ids"] = prompt_output["input_ids"].unsqueeze(0)
            prompt_output["attention_mask"] = prompt_output["attention_mask"].unsqueeze(0)

        tokenizer.padding_side = "right"
        response_output = tokenizer.pad(
            {"input_ids": output.token_ids},
            padding="max_length",
            max_length=max_response_len,
            return_tensors="pt",
            return_attention_mask=True,
        )
        if response_output["input_ids"].dim() == 1:
            response_output["input_ids"] = response_output["input_ids"].unsqueeze(0)
            response_output["attention_mask"] = response_output["attention_mask"].unsqueeze(0)
    finally:
        tokenizer.padding_side = original_padding_side

    # Assemble the fixed-shape batch of size 1 expected by the server.
    input_ids = torch.cat([prompt_output["input_ids"], response_output["input_ids"]], dim=1)
    attention_mask = torch.cat([prompt_output["attention_mask"], response_output["attention_mask"]], dim=1)
    position_ids = compute_position_id_with_mask(attention_mask)
    response_mask = response_output["attention_mask"]

    data_td = TensorDict(
        {
            "prompts": prompt_output["input_ids"],
            "responses": response_output["input_ids"],
            "input_ids": input_ids,
            "position_ids": position_ids,
            "response_mask": response_mask,
            "attention_mask": attention_mask,
            "loss_mask": response_mask,
        },
        batch_size=[1],
    )

    tu.assign_non_tensor(
        data_td,
        temperature=temperature,
        max_response_len=max_response_len,
    )
    result_td = await self.model_engine_server_handle.compute_log_prob.remote(data_td)

    # Keep only valid response tokens; drop right-padding region.
    log_probs_tensor = tu.get(result_td, "log_probs")
    output.old_log_probs = log_probs_tensor[0, :response_len].tolist()
    return output


@ray.remote
class FullyAsyncAgentLoopWorker(AgentLoopWorker):
Expand All @@ -137,8 +257,8 @@ def __init__(
teacher_servers: list[tuple[str, ray.actor.ActorHandle]] = None,
teacher_load_balancer_handle: ray.actor.ActorHandle = None,
reward_loop_worker_handles: list[ray.actor.ActorHandle] = None,
model_engine_server_handle: ray.actor.ActorHandle = None,
):
self.server_manager = FullyAsyncLLMServerManager(config, servers, load_balancer_handle)
super().__init__(
config,
servers,
Expand All @@ -147,6 +267,13 @@ def __init__(
teacher_load_balancer_handle,
reward_loop_worker_handles,
)
self.server_manager = FullyAsyncLLMServerManager(
config,
servers,
load_balancer_handle,
model_engine_server_handle,
tokenizer=self.tokenizer,
)


class FullyAsyncAgentLoopManager(AgentLoopManager):
Expand All @@ -159,10 +286,51 @@ def __init__(
reward_loop_worker_handles: list[ray.actor.ActorHandle] = None,
):
self.agent_loop_workers_class = FullyAsyncAgentLoopWorker
super().__init__(config, worker_group, rollout_resource_pool, teacher_model_manager, reward_loop_worker_handles)
super().__init__(

config,
worker_group,
rollout_resource_pool,
teacher_model_manager,
reward_loop_worker_handles,
)
if self.distillation_enabled:
raise NotImplementedError("Distillation is not implemented in FullyAsyncAgentLoopManager yet.")

async def _initialize_llm_servers(self):
    """Extend base class to also create ModelEngineReplica when configured.

    After the regular rollout replicas come up, optionally start a standalone
    model-engine replica used to recompute old log probs.
    """
    await super()._initialize_llm_servers()
    # Use .get() so configs without a model_engine_server section still work
    # (mirrors the config.get("actor_rollout_ref") pattern used in this file).
    model_engine_cfg = self.config.get("model_engine_server")
    if model_engine_cfg is not None and model_engine_cfg.get("enable_standalone"):
        await self._init_model_engine_replica()

async def _init_model_engine_replica(self):
    """Create a ModelEngineReplica and register it as a rollout replica.

    ``ModelEngineReplica.init_standalone()`` self-allocates a Ray resource
    pool, spawns ``ModelEngineWorker`` actors, calls ``init_model()``, and
    creates the OldLogProbServer — all in one call, mirroring how the
    vLLM/SGLang replicas bootstrap themselves.

    Afterwards ``self.model_engine_server_handle`` is set so that
    ``_init_agent_loop_workers()`` hands it to every FullyAsyncAgentLoopWorker.

    NOTE(review): a single replica is created here — confirm whether this
    should be multi-DP (one replica per data-parallel rank).
    """
    from verl.workers.rollout.model_engine_server import ModelEngineReplica, ModelEngineWorker

    engine_replica = ModelEngineReplica(
        replica_rank=len(self.rollout_replicas),
        full_config=self.config,
        worker_cls=ModelEngineWorker,
    )
    await engine_replica.init_standalone()
    self.rollout_replicas.append(engine_replica)

    # Expose the server handle so _init_agent_loop_workers passes it to workers.
    self.model_engine_server_handle = engine_replica.servers[0]

async def shutdown(self):
    """Shut down every rollout replica that exposes a ``shutdown`` hook.

    This covers the model-engine replica (and its OldLogProbServer) when one
    was created; replicas without a ``shutdown`` attribute are skipped.
    """
    for rollout_replica in self.rollout_replicas:
        if not hasattr(rollout_replica, "shutdown"):
            continue
        await rollout_replica.shutdown()

@auto_await
async def generate_sequences_single(self, prompts: DataProto) -> DataProto:
"""Split input batch and dispatch to agent loop workers.
Expand Down
Loading