[fully_async] feat: standalone log prob server (Model Engine Server) support #5990

Status: Open. sl-1314 wants to merge 41 commits into verl-project:main from meituan-search:standalone_old_log_prob_support (base: main). The diff below shows changes from 29 of the 41 commits.
Commits (41, all by sl-1314):

- 5b8d9b7 feat: add extra param sync groups for checkpoint engine
- f9f0133 feature: support old_log_prob_server
- 7829343 fix: unify configs
- 2d48570 fix: add wait for
- 734cd3e fix: restore debug code
- 70bccbd fix: refactor OldLogProbServer
- bb9e96b fix: run e2e exp
- 9c7f927 Revert "feat: add extra param sync groups for checkpoint engine"
- a2c8d3f fix: misc
- 2f4b7fa Merge branch 'main' into standalone_old_log_prob_support
- 15b0eb6 refactor: refactor old_log_prob_server to rollout replica
- adf2e16 fix: revert unnecessary change
- 8a01834 refactor: move old_log_prob_server to fully_async/
- 944c9f5 fix: restore unnecessary code
- c41f236 fix: restore unnecessary code
- 5707a9c simplify old_log_prob arch
- 59d3cd7 update: notes
- 9c97293 fix: misc
- 024b053 feat: support per tensor load
- 0ff02e4 update: misc
- ea93e97 fix: remove redundant param
- 060b0d2 update: rename
- 3a60c6c mv
- 3666fd1 fix: rename misc
- 2a0cb1f fix: clean unused code
- 2adb3cb fix: pre commit
- 08b92ce fix: mv config
- 37ac30c fix: misc
- cedf3e4 update: remove unused code
- a8e27ba update: remove unused code
- 7d75b0a Update verl/experimental/agent_loop/tool_agent_loop.py
- 0ff7de2 Merge branch 'standalone_old_log_prob_support' of https://github.com/…
- 999fb68 update: revise review comments
- 4b04e7b update: rename variables
- 752a831 update: remove mbridge code
- 3be4d98 update: rename
- 98f7c4a update: mv to extra_fields
- 354b026 update: remove other fix
- c920d62 update: simplify validate code
- 00252c5 update: remove shutdown
- 3559cd8 update: misc
```diff
@@ -19,6 +19,7 @@
 import ray
 import torch
 from omegaconf import DictConfig
+from tensordict import TensorDict

 from verl.experimental.agent_loop.agent_loop import (
     AgentLoopManager,
@@ -29,6 +30,8 @@
 from verl.experimental.teacher_loop import TeacherModelManager
 from verl.protocol import DataProto
 from verl.single_controller.ray import RayResourcePool, RayWorkerGroup
+from verl.utils import tensordict_utils as tu
+from verl.utils.model import compute_position_id_with_mask
 from verl.utils.ray_utils import auto_await
 from verl.utils.rollout_trace import (
     rollout_trace_op,
```
```diff
@@ -44,6 +47,18 @@ class FullyAsyncLLMServerManager(AsyncLLMServerManager):
     invisible to the AgentLoop.
     """

+    def __init__(
+        self,
+        config: DictConfig,
+        servers: list[tuple[str, ray.actor.ActorHandle]],
+        load_balancer_handle: ray.actor.ActorHandle,
+        model_engine_server_handle: ray.actor.ActorHandle = None,
+        tokenizer: Any = None,
+    ):
+        super().__init__(config=config, servers=servers, load_balancer_handle=load_balancer_handle)
+        self.model_engine_server_handle = model_engine_server_handle
+        self.tokenizer = tokenizer
+
     @rollout_trace_op
     async def generate(
         self,
```
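The generate() hunks that follow thread a transient `validate` flag through `sampling_params`: it is popped before the engine call (so the engine only sees real sampling options) and restored afterwards for multi-turn reuse. A minimal sketch of that pop/restore pattern (the `call_engine` helper here is hypothetical, not a verl API):

```python
def call_engine(sampling_params):
    # Stand-in for the real inference call; it must not see helper-only keys.
    assert "validate" not in sampling_params
    return {"token_ids": [1, 2, 3]}


sampling_params = {"temperature": 0.7, "validate": True}

# Pop the flag so the engine receives only genuine sampling parameters.
validate = sampling_params.pop("validate", None)
out = call_engine(sampling_params)

# Restore it afterwards so the next turn of a multi-turn rollout sees it again.
if validate:
    sampling_params["validate"] = validate
```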
```diff
@@ -74,10 +89,12 @@ async def generate(
         elif "max_new_tokens" in sampling_params:
             limit_key = "max_new_tokens"
         original_max_tokens = sampling_params.get(limit_key) if limit_key else None
+        validate = sampling_params.pop("validate", None)

         final_output = TokenOutput(
             token_ids=[],
             log_probs=[],
+            old_log_probs=[],
             num_preempted=0,
         )
         min_global_steps, max_global_steps = None, None
@@ -91,11 +108,17 @@ async def generate(
                 image_data=image_data,
                 video_data=video_data,
             )

+            current_prompt_ids = prompt_ids + final_output.token_ids
+            current_temperature = sampling_params.get("temperature", 1.0)
+            # Skip old log prob computation during validation
+            if not validate:
+                output = await self._compute_old_log_prob(output, current_prompt_ids, current_temperature)
             # 2. merge output into final_output
             final_output.token_ids.extend(output.token_ids)
             if output.log_probs is not None:
                 final_output.log_probs.extend(output.log_probs)
+            if output.old_log_probs is not None:
+                final_output.old_log_probs.extend(output.old_log_probs)
             if output.routed_experts is not None:
                 if final_output.routed_experts is None:
                     final_output.routed_experts = output.routed_experts
```
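The merge step in the hunk above accumulates each partial-rollout turn into `final_output`, keeping `old_log_probs` aligned index-for-index with `token_ids`. A simplified pure-Python sketch of that accumulation (this `TokenOutput` is a cut-down stand-in carrying only the merged fields, not verl's actual class):

```python
from dataclasses import dataclass, field


@dataclass
class TokenOutput:
    # Cut-down stand-in: only the fields the merge loop touches.
    token_ids: list = field(default_factory=list)
    log_probs: list = field(default_factory=list)
    old_log_probs: list = field(default_factory=list)


def merge(final: TokenOutput, chunk: TokenOutput) -> None:
    """Append one generation turn's tokens and per-token log probs."""
    final.token_ids.extend(chunk.token_ids)
    if chunk.log_probs is not None:
        final.log_probs.extend(chunk.log_probs)
    if chunk.old_log_probs is not None:
        final.old_log_probs.extend(chunk.old_log_probs)


final = TokenOutput()
# Two turns of a partial rollout; old_log_probs stay aligned with token_ids.
merge(final, TokenOutput(token_ids=[1, 2], log_probs=[-0.1, -0.2], old_log_probs=[-0.1, -0.3]))
merge(final, TokenOutput(token_ids=[3], log_probs=[-0.4], old_log_probs=[-0.5]))
```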
```diff
@@ -121,11 +144,108 @@ async def generate(
             # 4. check stop reason
             if output.stop_reason not in ("aborted", "abort") or not self.config.async_training.partial_rollout:
                 break
+        if validate:  # restore for multi-turn use
+            sampling_params["validate"] = validate
+        final_output.extra_fields["global_steps"] = global_steps
+        final_output.extra_fields["min_global_steps"] = min_global_steps
+        final_output.extra_fields["max_global_steps"] = max_global_steps
         return final_output

+    async def _compute_old_log_prob(self, output: TokenOutput, context_prompt_ids, temperature: float):
+        if self.model_engine_server_handle is None:
+            return output
+        # Convert TokenOutput -> fixed-shape TensorDict for OldLogProbServer.
+        if self.config.get("actor_rollout_ref"):
+            rollout_config = self.config.actor_rollout_ref.rollout
+        else:
+            rollout_config = self.config.rollout
+
+        # Prompt grows during partial-rollout/multi-turn. Reserve prompt slots for
+        # [original prompt + at most full response length] to keep shapes static.
+        max_prompt_len = int(rollout_config.prompt_length) + int(rollout_config.response_length)
+        max_response_len = int(rollout_config.response_length)
+
+        # Only recompute old_log_probs for newly generated tokens in this turn.
+        if len(output.token_ids) == 0:
+            output.old_log_probs = []
+            return output
+
+        prompt_len = len(context_prompt_ids)
+        response_len = len(output.token_ids)
+
+        if prompt_len > max_prompt_len:
+            raise ValueError(
+                f"prompt length {prompt_len} exceeds padded prompt length {max_prompt_len} "
+                "for old_log_prob recomputation"
+            )
+        if response_len > max_response_len:
+            print(
+                f"response length {response_len} exceeds padded response length {max_response_len} "
+                "for old_log_prob recomputation"
+            )
+            output.token_ids = output.token_ids[:max_response_len]
+            response_len = max_response_len
+
+        tokenizer = self.tokenizer
+        if tokenizer is None:
+            raise RuntimeError("tokenizer is required for old_log_prob recomputation padding")
+
+        original_padding_side = tokenizer.padding_side
+        tokenizer.padding_side = "left"
+        prompt_output = tokenizer.pad(
+            {"input_ids": context_prompt_ids},
+            padding="max_length",
+            max_length=max_prompt_len,
+            return_tensors="pt",
+            return_attention_mask=True,
+        )
+        if prompt_output["input_ids"].dim() == 1:
+            prompt_output["input_ids"] = prompt_output["input_ids"].unsqueeze(0)
+            prompt_output["attention_mask"] = prompt_output["attention_mask"].unsqueeze(0)
+
+        tokenizer.padding_side = "right"
+        response_output = tokenizer.pad(
+            {"input_ids": output.token_ids},
+            padding="max_length",
+            max_length=max_response_len,
+            return_tensors="pt",
+            return_attention_mask=True,
+        )
+        if response_output["input_ids"].dim() == 1:
+            response_output["input_ids"] = response_output["input_ids"].unsqueeze(0)
+            response_output["attention_mask"] = response_output["attention_mask"].unsqueeze(0)
+        tokenizer.padding_side = original_padding_side
+
+        input_ids = torch.cat([prompt_output["input_ids"], response_output["input_ids"]], dim=1)
+        attention_mask = torch.cat([prompt_output["attention_mask"], response_output["attention_mask"]], dim=1)
+        position_ids = compute_position_id_with_mask(attention_mask)
+        response_mask = response_output["attention_mask"]
+
+        data_td = TensorDict(
+            {
+                "prompts": prompt_output["input_ids"],
+                "responses": response_output["input_ids"],
+                "input_ids": input_ids,
+                "position_ids": position_ids,
+                "response_mask": response_mask,
+                "attention_mask": attention_mask,
+                "loss_mask": response_mask,
+            },
+            batch_size=[1],
+        )
+
+        tu.assign_non_tensor(
+            data_td,
+            temperature=temperature,
+            max_response_len=max_response_len,
+        )
+        result_td = await self.model_engine_server_handle.compute_log_prob.remote(data_td)
+
+        # Keep only valid response tokens; drop right-padding region.
+        log_probs_tensor = tu.get(result_td, "log_probs")
+        output.old_log_probs = log_probs_tensor[0, :response_len].tolist()
+        return output
+
+
 @ray.remote
 class FullyAsyncAgentLoopWorker(AgentLoopWorker):
```
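`_compute_old_log_prob` above builds a fixed-shape batch: left-pad the prompt, right-pad the response, derive position ids from the attention mask, and later keep only the first `response_len` log probs. A pure-Python sketch of that shape logic, assuming pad token id 0 and hypothetical helper names (the real code uses the HF tokenizer's `pad` and verl's `compute_position_id_with_mask`):

```python
def left_pad(ids, length, pad=0):
    # Prompts are left-padded so the last prompt token abuts the response.
    n = length - len(ids)
    return [pad] * n + ids, [0] * n + [1] * len(ids)


def right_pad(ids, length, pad=0):
    # Responses are right-padded; the mask marks the valid (real) tokens.
    n = length - len(ids)
    return ids + [pad] * n, [1] * len(ids) + [0] * n


def position_ids_from_mask(mask):
    # Mirrors the cumsum-and-clip idea: left-padding repeats position 0,
    # then real tokens count up 0, 1, 2, ...
    pos, count = [], 0
    for m in mask:
        count += m
        pos.append(max(count - 1, 0))
    return pos


prompt_ids, prompt_mask = left_pad([11, 12, 13], 5)
resp_ids, resp_mask = right_pad([21, 22], 4)
input_ids = prompt_ids + resp_ids            # fixed length 5 + 4 on every call
attention_mask = prompt_mask + resp_mask
position_ids = position_ids_from_mask(attention_mask)

# After the server returns log probs over the padded response, only the
# first `response_len` entries are kept, as in the diff's final slice.
response_len = sum(resp_mask)
```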
```diff
@@ -137,8 +257,8 @@ def __init__(
         teacher_servers: list[tuple[str, ray.actor.ActorHandle]] = None,
         teacher_load_balancer_handle: ray.actor.ActorHandle = None,
         reward_loop_worker_handles: list[ray.actor.ActorHandle] = None,
+        model_engine_server_handle: ray.actor.ActorHandle = None,
     ):
-        self.server_manager = FullyAsyncLLMServerManager(config, servers, load_balancer_handle)
         super().__init__(
             config,
             servers,
@@ -147,6 +267,13 @@ def __init__(
             teacher_load_balancer_handle,
             reward_loop_worker_handles,
         )
+        self.server_manager = FullyAsyncLLMServerManager(
+            config,
+            servers,
+            load_balancer_handle,
+            model_engine_server_handle,
+            tokenizer=self.tokenizer,
+        )
```
```diff
 class FullyAsyncAgentLoopManager(AgentLoopManager):
@@ -159,10 +286,51 @@ def __init__(
         reward_loop_worker_handles: list[ray.actor.ActorHandle] = None,
     ):
         self.agent_loop_workers_class = FullyAsyncAgentLoopWorker
-        super().__init__(config, worker_group, rollout_resource_pool, teacher_model_manager, reward_loop_worker_handles)
+        super().__init__(
+            config,
+            worker_group,
+            rollout_resource_pool,
+            teacher_model_manager,
+            reward_loop_worker_handles,
+        )
         if self.distillation_enabled:
             raise NotImplementedError("Distillation is not implemented in FullyAsyncAgentLoopManager yet.")

+    async def _initialize_llm_servers(self):
+        """Extend base class to also create ModelEngineReplica when configured."""
+        await super()._initialize_llm_servers()
+        if self.config.model_engine_server.enable_standalone:
+            await self._init_model_engine_replica()
+
+    async def _init_model_engine_replica(self):
+        """Create ModelEngineReplica, call init_standalone, and append to rollout_replicas.
+
+        ModelEngineReplica.init_standalone() self-allocates a Ray resource pool,
+        spawns ModelEngineWorker actors, calls init_model(), and creates the
+        OldLogProbServer — all in one call, exactly like vLLM/SGLang replicas.
+
+        After this, self.model_engine_server_handle is set so that
+        _init_agent_loop_workers() passes it to every FullyAsyncAgentLoopWorker.
+        """
+        from verl.workers.rollout.model_engine_server import ModelEngineReplica, ModelEngineWorker
+
+        replica = ModelEngineReplica(
+            replica_rank=len(self.rollout_replicas),
+            full_config=self.config,
+            worker_cls=ModelEngineWorker,
+        )
+        await replica.init_standalone()
+        self.rollout_replicas.append(replica)
+
+        # Expose the server handle so _init_agent_loop_workers passes it to workers.
+        self.model_engine_server_handle = replica.servers[0]
+
+    async def shutdown(self):
+        """Shut down OldLogProbServer if one was created."""
+        for replica in self.rollout_replicas:
+            if hasattr(replica, "shutdown"):
+                await replica.shutdown()
+
     @auto_await
     async def generate_sequences_single(self, prompts: DataProto) -> DataProto:
         """Split input batch and dispatch to agent loop workers.
```

Review comment from a collaborator on the `replica = ModelEngineReplica(` line: "This should be multi-DP?"
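Stripping away Ray, the plumbing the manager sets up (create a replica, expose `model_engine_server_handle`, inject it into each worker's server manager) reduces to optional dependency injection: when no standalone server is configured, old-log-prob recomputation is simply skipped. A mock sketch of that pattern, with all classes here hypothetical stand-ins rather than verl APIs:

```python
import asyncio


class MockModelEngineServer:
    """Stand-in for the standalone Model Engine Server actor."""

    async def compute_log_prob(self, token_ids):
        await asyncio.sleep(0)  # simulate the remote round-trip
        return [-0.1 * t for t in token_ids]  # fake per-token log probs


class MockServerManager:
    """Mirrors the optional-handle pattern: no handle means no recomputation."""

    def __init__(self, model_engine_server_handle=None):
        self.model_engine_server_handle = model_engine_server_handle

    async def compute_old_log_prob(self, token_ids):
        if self.model_engine_server_handle is None:
            return None  # standalone server disabled: skip, like the early return in the diff
        return await self.model_engine_server_handle.compute_log_prob(token_ids)


with_server = MockServerManager(MockModelEngineServer())
without_server = MockServerManager()
assert asyncio.run(with_server.compute_old_log_prob([1, 2])) == [-0.1, -0.2]
assert asyncio.run(without_server.compute_old_log_prob([1, 2])) is None
```

The same shape is what makes the feature opt-in: workers receive the handle only when `model_engine_server.enable_standalone` is set, and otherwise behave exactly as before.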