
[fully_async, ckpt, rollout, trainer, tool, cfg] fix: ROCm async training compatibility for AMD MI300X#6002

Open
xiaohong42 wants to merge 1 commit into verl-project:main from xiaohong42:fix/rocm-async-training-compatibility

Conversation


@xiaohong42 xiaohong42 commented Apr 14, 2026

What does this PR do?

Fix multiple issues that prevent fully async FSDP2 training from working on AMD ROCm platforms (MI300X series).

Environment:

  • AMD Instinct MI308X (8× GPU, 192 GB HBM each), ROCm 7.2.0, PyTorch 2.10.0+rocm7.2.0, vLLM v0.19.1rc0
  • Cross-validated on NVIDIA H20 with CUDA, vLLM 0.18.2 (latest verl main + this patch): no regression observed up to step 86, where a pre-existing bug (AttributeError: 'list' object has no attribute 'dim' in agent_loop.py:696) is hit — this bug exists with or without this patch and is unrelated to these changes.

Training curves (MI308X vs H20) and training script (attachments): dapo_7b_fully_async.sh (training script), qwen2.5_7b_fully_async (training-curve image)

Test

Validated by full async FSDP2 DAPO/GRPO RL + ReTool training on AMD MI308X:

  • 250+ global steps completed, 60+ weight synchronization cycles without OOM or deadlock
  • Cross-validated on NVIDIA H20: training runs normally with this patch applied (no regression)

H20 (CUDA) cross-validation note: applied this patch on latest verl main (commit 9b54564) + vLLM 0.18.2 on NVIDIA H20. Training ran normally up to step 86 / global_step 344, where a pre-existing bug in agent_loop.py:698 is hit:

```
File "verl/experimental/agent_loop/agent_loop.py", line 698, in _agent_loop_postprocess
    if response_mask_output["input_ids"].dim() == 1:
AttributeError: 'list' object has no attribute 'dim'
```
This is caused by tokenizer.pad() returning a Python list instead of a torch.Tensor for response_mask in certain edge cases, even with return_tensors="pt". This bug exists on the current main branch with or without this patch — it is not introduced by any changes in this PR. The file agent_loop.py is not modified in this PR.

MI308X (ROCm 7.2): 250+ global steps, 60+ weight syncs completed without OOM or deadlock. The same agent_loop.py bug was also encountered on MI308X at a later step, confirming it is platform-independent.

API and Usage Example

No API changes. All fixes are internal implementation details.

Design & Code Changes

  1. NCCL checkpoint engine: unify buffers to torch tensors (nccl_checkpoint_engine.py)

    • Replace cupy buffers with torch tensors; cupy causes HIP stream synchronization issues on ROCm
  2. Add HSA_NO_SCRATCH_RECLAIM env var (constants_ppo.py)

    • Required by AMD RCCL on MI300X; without it, FSDP initialization fails with ncclSystemError
  3. Fix numpy.bool_ JSON serialization (ray_trainer.py)

    • Add default=str fallback for json.dumps since numpy 2.x bool_ is no longer a Python bool subclass
  4. Materialize generator before send_weights (engine_workers.py)

    • get_per_tensor_param() returns a generator containing full_tensor() calls that trigger FSDP all_gather
    • Lazy consumption causes rank-misaligned collective calls → deadlock
    • Fix: list() to materialize + torch.cuda.synchronize() before sending
  5. ZMQ IPC handle: use rank instead of GPU UUID (vllm_rollout.py, utils.py)

    • On ROCm, CheckpointEngineWorker and vLLM worker see different GPU UUIDs due to different CUDA_VISIBLE_DEVICES/HIP_VISIBLE_DEVICES settings
    • Use deterministic rank number instead
  6. Clean up stale ZMQ IPC socket files (bucketed_weight_transfer.py)

    • Remove leftover /tmp/rl-colocate-zmq-rank-*.sock files before bind() and after close() to prevent Address already in use on restart
  7. Fix Hydra searchpath (fully_async_ppo_trainer.yaml)

    • Use pkg://verl.trainer.config instead of file://verl/trainer/config for editable installs
  8. Sandbox Ray actor reuse (sandbox_fusion_tools.py)

    • Add name and get_if_exists=True to prevent duplicate ExecutionWorker actor creation
  9. Persist weight sync buffers to prevent OOM (nccl_checkpoint_engine.py, vllm_rollout.py, bucketed_weight_transfer.py)

    • On ROCm/HIP, torch.cuda.empty_cache() does not effectively return physical memory to the system
    • Repeated alloc/free of large buffers (NCCL 4 GB + IPC 2 GB) causes HIP memory fragmentation → OOM after ~10 weight sync cycles
    • Fix: allocate once and reuse across sync iterations
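The numpy.bool_ serialization fix (item 3) reduces to giving json.dumps a fallback encoder. A minimal self-contained sketch, using a stand-in class instead of numpy (under numpy 2.x, bool_ is similarly not a Python bool subclass and makes json.dumps raise TypeError):

```python
import json

class NumpyBoolStandIn:
    """Stand-in for numpy.bool_ under numpy 2.x: not a Python bool subclass,
    so json.dumps raises TypeError on it unless a fallback is provided."""
    def __init__(self, value):
        self.value = value
    def __str__(self):
        return str(self.value)

metrics = {"converged": NumpyBoolStandIn(True), "step": 42}

# Without the fallback, json.dumps(metrics) raises TypeError.
# default=str is invoked for any object json cannot encode natively.
serialized = json.dumps(metrics, default=str)
print(serialized)  # {"converged": "True", "step": 42}
```

The trade-off of default=str is that the boolean is serialized as the string "True" rather than a JSON true, which is acceptable for the logging path this fix targets.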
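The deadlock avoidance in item 4 amounts to forcing eager evaluation of a generator so every rank issues its collectives in the same order at the same point. A minimal stand-in (plain strings replace FSDP shards; the real code additionally calls torch.cuda.synchronize() before sending):

```python
def get_per_tensor_param():
    """Stand-in for the engine's weight generator: in verl, consuming each
    item triggers a full_tensor() all_gather collective across FSDP ranks."""
    for name in ("embed.weight", "lm_head.weight"):
        yield name, f"<full tensor of {name}>"

# Lazy consumption inside send_weights() lets ranks interleave their
# collectives differently and deadlock; list() forces every all_gather to
# run here, in a deterministic order, before any network send begins.
per_tensor_param = list(get_per_tensor_param())
# torch.cuda.synchronize() would follow in the real code, then:
# await checkpoint_engine.send_weights(per_tensor_param)
print(len(per_tensor_param))  # 2
```

Note the reviewer's objection below: materializing the generator gathers all full tensors onto each GPU at once, which can OOM for large models.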
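The stale-socket cleanup in item 6 is an unlink-before-bind (and after-close) pattern. Sketched here with a plain file, since a ZMQ IPC socket path is just a filesystem entry; the path name follows the convention described above but is illustrative:

```python
import os
import tempfile

def remove_stale_socket(path: str) -> None:
    """Delete a leftover IPC socket file so a later zmq bind() does not
    fail with 'Address already in use' after an unclean shutdown."""
    try:
        os.unlink(path)
    except FileNotFoundError:
        pass  # nothing stale to clean up

# Simulate a leftover socket file from a crashed run.
sock_path = os.path.join(tempfile.gettempdir(), "rl-colocate-zmq-rank-0.sock")
open(sock_path, "w").close()

remove_stale_socket(sock_path)    # before bind()
print(os.path.exists(sock_path))  # False
remove_stale_socket(sock_path)    # idempotent: safe to call after close() too
```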
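Item 9's allocate-once strategy can be illustrated with a bytearray standing in for the device tensor (in the real code these are the ~4 GB NCCL and ~2 GB IPC staging buffers; the class and method names here are hypothetical):

```python
class PersistentSyncBuffer:
    """Reuse one staging buffer across weight-sync cycles instead of
    allocating and freeing it every cycle, which fragments HIP memory on
    ROCm (empty_cache() does not reliably return memory to the system)."""

    def __init__(self):
        self._buf = None

    def get(self, nbytes: int) -> bytearray:
        # Grow only when a larger bucket arrives; otherwise reuse as-is.
        if self._buf is None or len(self._buf) < nbytes:
            self._buf = bytearray(nbytes)  # torch.empty(...) in the real code
        return self._buf

sync_buf = PersistentSyncBuffer()
first = sync_buf.get(1024)
second = sync_buf.get(512)  # smaller request: same allocation is reused
print(first is second)      # True
```

Keeping the buffer alive trades a fixed amount of resident memory for stable, fragmentation-free behavior over hundreds of sync cycles.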

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

  • Read the Contribute Guide.
  • Apply pre-commit checks: pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always
  • Add / Update the documentation. The official documentation will be compiled after the merge.
  • Add unit or end-to-end test(s) to the CI workflow to cover all the code. If not feasible, explain why: These fixes target ROCm-specific runtime behavior (HIP memory management, RCCL env vars, GPU UUID mismatch) that cannot be reproduced in CI without AMD GPU hardware.
  • Once your PR is ready for CI, send a message in the ci-request channel in the verl Slack workspace. (If not accessible, please try the Feishu group (飞书群).)
  • If your PR is related to the recipe submodule, please also update the reference to the submodule commit via git submodule update --remote or cd recipe && git pull origin main.

@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.


root seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you have already a GitHub account, please add the email address used for this commit to your account.
You have signed the CLA already but the status is still pending? Let us recheck it.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request optimizes weight transfer by implementing buffer reuse in the BucketedWeightSender and persisting the weight sender instance. It also replaces CuPy with Torch in the NCCL checkpoint engine and updates ZMQ handle naming. Review feedback identifies a critical bug where ZMQ handles are inconsistently named using global ranks on the sender side and local ranks on the receiver side, which will cause failures in multi-node environments. Additionally, a suggestion was made to explicitly nullify buffers when unlinking shared memory to prevent potential use-after-free issues.

@xiaohong42 force-pushed the fix/rocm-async-training-compatibility branch from 31ea03f to 92cdd55 on April 14, 2026 at 07:33
Fix multiple issues that prevent fully async FSDP2 training from working
on AMD ROCm platforms (tested on MI308X with ROCm 7.2, also verified
compatible on H20 with CUDA):

1. Unify NCCL checkpoint engine buffers to torch tensors, replacing cupy
   which causes HIP stream synchronization issues on ROCm
2. Add HSA_NO_SCRATCH_RECLAIM env var required by AMD RCCL
3. Fix numpy.bool_ JSON serialization with numpy 2.x
4. Materialize generator before send_weights to prevent FSDP all_gather
   deadlock across ranks
5. Use deterministic rank-based ZMQ IPC handles instead of GPU UUID which
   differs between checkpoint engine and vLLM workers on ROCm
6. Clean up stale ZMQ IPC socket files to prevent bind failures on restart
7. Fix Hydra searchpath to use pkg:// instead of file:// for editable installs
8. Add get_if_exists to sandbox Ray actor to prevent duplicate creation
9. Persist weight sync buffers (NCCL + IPC) to prevent HIP memory
   fragmentation OOM from repeated alloc/free cycles

Made-with: Cursor
@xiaohong42 force-pushed the fix/rocm-async-training-compatibility branch from 92cdd55 to d01bdc3 on April 14, 2026 at 07:43
```python
if self.config.rollout.checkpoint_engine.backend != "naive":
    per_tensor_param, _ = self.actor.engine.get_per_tensor_param()
    per_tensor_param = list(per_tensor_param)  # added: materialize before sending
    await self.checkpoint_engine.send_weights(per_tensor_param)
```

This will materialize the weight generator and gather all sharded weights into each GPU, causing CUDA OOM for large models.


```python
self.device_uuid = get_device_uuid(get_device_id())
self.zmq_handle = f"ipc:///tmp/rl-colocate-zmq-{self.device_uuid}.sock"  # before: GPU UUID
self.zmq_handle = f"ipc:///tmp/rl-colocate-zmq-rank-{rank % local_world_size}.sock"  # after: rank-based
```

This will conflict for multiple vLLM replicas on the same node, e.g. 2 replicas with TP=4 located on the same node.


```python
def prepare(self) -> MasterMetadata:
    # For master process, use cupy instead of torch to avoid memory register error
    # when `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`.
```

Please respect this comment.

