-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[fully_async, ckpt, rollout, trainer, tool, cfg] fix: ROCm async training compatibility for AMD MI300X #6002
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,6 @@ | ||
| hydra: | ||
| searchpath: | ||
| - file://verl/trainer/config | ||
| - pkg://verl.trainer.config | ||
|
|
||
| defaults: | ||
| - ppo_trainer | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,7 +32,7 @@ | |
| from verl.trainer.distillation import distillation_ppo_loss, is_distillation_enabled | ||
| from verl.utils import tensordict_utils as tu | ||
| from verl.utils.config import omega_conf_to_dataclass | ||
| from verl.utils.device import get_device_name, is_npu_available, set_expandable_segments | ||
| from verl.utils.device import get_device_name, get_torch_device, is_npu_available, set_expandable_segments | ||
| from verl.utils.distributed import initialize_global_process_group_ray, set_numa_affinity | ||
| from verl.utils.flops_counter import FlopsCounter | ||
| from verl.utils.import_utils import import_external_libs | ||
|
|
@@ -674,7 +674,9 @@ async def update_weights(self, global_steps: int = None): | |
| # 0. send_weights only for async training with disaggregated trainer and rollout | ||
| if self.config.rollout.checkpoint_engine.backend != "naive": | ||
| per_tensor_param, _ = self.actor.engine.get_per_tensor_param() | ||
| await self.checkpoint_engine.send_weights(per_tensor_param) | ||
| per_tensor_param = list(per_tensor_param) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This will materialize the weight generator and gather all sharded weights into each GPU, causing CUDA OOM for large models. |
||
| get_torch_device().synchronize() | ||
| await self.checkpoint_engine.send_weights(iter(per_tensor_param)) | ||
| return | ||
|
|
||
| set_expandable_segments(False) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -95,7 +95,7 @@ def __init__( | |
| self.sleep_level = VLLM_SLEEP_LEVEL | ||
|
|
||
| self.device_uuid = get_device_uuid(get_device_id()) | ||
| self.zmq_handle = f"ipc:///tmp/rl-colocate-zmq-{self.device_uuid}.sock" | ||
| self.zmq_handle = f"ipc:///tmp/rl-colocate-zmq-rank-{rank % local_world_size}.sock" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This will conflict for multiple vLLM replicas on the same node, e.g. 2 replicas with TP=4 located on the same node. |
||
|
|
||
| self.use_shm = not is_support_ipc() | ||
| if self.use_shm: | ||
|
|
@@ -163,13 +163,14 @@ async def update_weights( | |
| kwargs={**kwargs, "use_shm": self.use_shm}, | ||
| ) | ||
|
|
||
| bucket_size_mb = self.config.checkpoint_engine.update_weights_bucket_megabytes | ||
| sender = BucketedWeightSender( | ||
| zmq_handle=self.zmq_handle, | ||
| bucket_size_mb=bucket_size_mb, | ||
| use_shm=self.use_shm, | ||
| ) | ||
| await sender.async_send_weights(weights) | ||
| if not hasattr(self, "_weight_sender") or self._weight_sender is None: | ||
| bucket_size_mb = self.config.checkpoint_engine.update_weights_bucket_megabytes | ||
| self._weight_sender = BucketedWeightSender( | ||
| zmq_handle=self.zmq_handle, | ||
| bucket_size_mb=bucket_size_mb, | ||
| use_shm=self.use_shm, | ||
| ) | ||
| await self._weight_sender.async_send_weights(weights) | ||
|
|
||
| if future is not None: | ||
| await future | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please respect this comment.