
[Bug] FSDP2 CPUOffloadPolicy + state_dict() crashes with device mismatch during update_weights (non-LoRA full-weight training) #5995

@yifannnwu

Description


System Info

  • verl: latest main branch
  • Python: 3.12
  • PyTorch: 2.10
  • vLLM: 0.19.0
  • Ray: latest
  • Hardware: H200 (140GB each), multi-node (8-16 nodes)
  • CUDA: 12.x
  • Model: Qwen3.5

Reproduction

Running GRPO with the FSDP2 hybrid engine (async vLLM rollout + FSDP2 training). The training step (forward + backward + optimizer step) completes successfully. The crash occurs at update_weights, when verl calls get_per_tensor_param() to extract the trained weights for syncing to the vLLM rollout engine.

Config (relevant parts):

actor_rollout_ref.actor.strategy=fsdp2
actor_rollout_ref.actor.fsdp_config.fsdp_size=64
actor_rollout_ref.actor.fsdp_config.param_offload=true
actor_rollout_ref.actor.fsdp_config.optimizer_offload=true
actor_rollout_ref.actor.fsdp_config.offload_policy=true
actor_rollout_ref.hybrid_engine=true
actor_rollout_ref.rollout.name=vllm
actor_rollout_ref.rollout.mode=async

Error

File "verl/workers/engine_workers.py", line 715, in update_weights
    per_tensor_param, peft_config = self.actor.engine.get_per_tensor_param(...)
File "verl/workers/engine/fsdp/transformer_impl.py", line 774, in get_per_tensor_param
    params = self.module.state_dict()
...
RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.

Analysis

In transformer_impl.py:get_per_tensor_param():

  1. Line 751 calls load_fsdp_model_to_gpu(self.module), which for FSDP2 just calls model.to(device). However, with CPUOffloadPolicy, FSDP2 manages param placement automatically and model.to(device) does not actually move the offloaded shards back to GPU.

  2. Line 774 calls self.module.state_dict() — this crashes because internal PyTorch state_dict hooks encounter params whose storage is on CPU but the wrapper expects CUDA.

  3. Note: when offload_policy=True, verl sets _is_offload_param = False (line 390), so the manual offload/load logic in get_per_tensor_param is bypassed — verl relies on FSDP2 to handle it, but FSDP2's state_dict() doesn't handle the CPU↔CUDA transition correctly.

What I've tried

  1. Swapping the order at line 791: changed param.to(device).full_tensor() to param.full_tensor().to(device). This doesn't help because the crash is inside state_dict() itself (line 774), not in the generator expression.

  2. Using get_model_state_dict from torch.distributed.checkpoint.state_dict: Replaced self.module.state_dict() with get_model_state_dict(self.module, options=StateDictOptions(full_state_dict=False, cpu_offload=False)). Same crash — get_model_state_dict internally calls state_dict().

  3. All offload combinations: param_offload=true only, optimizer_offload=true only, offload_policy=true only — all crash at the same point.

Expected behavior

get_per_tensor_param() should be able to extract weights from an FSDP2 model with CPUOffloadPolicy enabled, either by properly unsharding/materializing params before state_dict(), or by using an FSDP2-aware state dict extraction path.
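A minimal sketch of the first option, assuming the FSDP2 FSDPModule.unshard()/reshard() APIs (available in recent PyTorch; the helper state_dict_with_unshard is hypothetical, not existing verl code):

```python
# Sketch (assumption, not verl code): explicitly unshard FSDP2-managed
# modules before state_dict(), then reshard afterwards, so CPU-offloaded
# shards are materialized instead of relying on model.to(device).
import torch.nn as nn

try:
    # FSDP2 module type; exported from torch.distributed.fsdp in recent PyTorch
    from torch.distributed.fsdp import FSDPModule
except ImportError:
    FSDPModule = ()  # empty tuple makes isinstance() a no-op on older PyTorch


def state_dict_with_unshard(module: nn.Module):
    fsdp_submodules = [m for m in module.modules() if isinstance(m, FSDPModule)]
    for m in fsdp_submodules:
        m.unshard()  # all-gather/materialize sharded (incl. CPU-offloaded) params
    try:
        return module.state_dict()
    finally:
        for m in fsdp_submodules:
            m.reshard()  # return params to their sharded/offloaded state
```

On a plain (non-FSDP) module this degrades to a bare state_dict() call; whether unshard() fully resolves the CPU↔CUDA mismatch under CPUOffloadPolicy is exactly what needs verifying.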

Workaround

Disable all offload (param_offload=false, optimizer_offload=false, offload_policy=false) and use more nodes to fit in GPU memory. This works but wastes compute resources.
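Concretely, the workaround config overrides are:

actor_rollout_ref.actor.fsdp_config.param_offload=false
actor_rollout_ref.actor.fsdp_config.optimizer_offload=false
actor_rollout_ref.actor.fsdp_config.offload_policy=false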
