Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 5 additions & 12 deletions verl/checkpoint_engine/nccl_checkpoint_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ def __init__(
self.group_name = group_name
self.rebuild_group = rebuild_group
self.rollout_dtype = rollout_dtype
self.send_buf = None
self.recv_buf = None

# start zeromq server for broadcasting bucket tensor metadata
self.is_master = is_master
Expand All @@ -126,13 +128,9 @@ def __init__(
self._start_zmq_server()

def prepare(self) -> MasterMetadata:
# For master process, use cupy instead of torch to avoid memory register error
# when `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please address this comment.

if self.is_master:
self.send_buf = cp.zeros(self.bucket_size, dtype=cp.uint8)
self.recv_buf = cp.zeros(self.bucket_size, dtype=cp.uint8)
else:
if self.send_buf is None:
self.send_buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device="cuda")
if self.recv_buf is None:
self.recv_buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device="cuda")

return MasterMetadata(zmq_ip=self.ip, zmq_port=self.listen_port) if self.is_master else None
Expand All @@ -145,11 +143,6 @@ def finalize(self):
self.rank = None
self.world_size = None

self.send_buf = None
self.recv_buf = None

torch.cuda.empty_cache()

@classmethod
def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metadata: list[dict]):
trainer_kwargs = {
Expand Down Expand Up @@ -274,7 +267,7 @@ async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None,
"dtype": weight.dtype,
"offset": offset,
}
send_buf[offset : offset + weight.nbytes] = cp.asarray(weight.view(-1).view(torch.uint8))
send_buf[offset : offset + weight.nbytes] = weight.view(-1).view(torch.uint8)
offset += weight.nbytes

# broadcast last bucket
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
hydra:
searchpath:
- file://verl/trainer/config
- pkg://verl.trainer.config

defaults:
- ppo_trainer
Expand Down
2 changes: 1 addition & 1 deletion verl/tools/sandbox_fusion_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def init_execution_pool(
if mode == PoolMode.ThreadMode:
return (
ray.remote(ExecutionWorker)
.options(max_concurrency=num_workers)
.options(name="sandbox-execution-pool", get_if_exists=True, max_concurrency=num_workers)
.remote(enable_global_rate_limit=enable_global_rate_limit, rate_limit=rate_limit)
)
else:
Expand Down
1 change: 1 addition & 0 deletions verl/trainer/constants_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
# https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/maintenref/envvar/envref_07_0143.html
"HCCL_HOST_SOCKET_PORT_RANGE": "auto",
"HCCL_NPU_SOCKET_PORT_RANGE": "auto",
"HSA_NO_SCRATCH_RECLAIM": "1",
},
}

Expand Down
2 changes: 1 addition & 1 deletion verl/trainer/ppo/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ def _dump_generations(self, inputs, outputs, gts, scores, reward_extra_infos_dic
lines = []
for i in range(n):
entry = {k: v[i] for k, v in base_data.items()}
lines.append(json.dumps(entry, ensure_ascii=False))
lines.append(json.dumps(entry, ensure_ascii=False, default=str))

with open(filename, "w") as f:
f.write("\n".join(lines) + "\n")
Expand Down
6 changes: 4 additions & 2 deletions verl/workers/engine_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from verl.trainer.distillation import distillation_ppo_loss, is_distillation_enabled
from verl.utils import tensordict_utils as tu
from verl.utils.config import omega_conf_to_dataclass
from verl.utils.device import get_device_name, is_npu_available, set_expandable_segments
from verl.utils.device import get_device_name, get_torch_device, is_npu_available, set_expandable_segments
from verl.utils.distributed import initialize_global_process_group_ray, set_numa_affinity
from verl.utils.flops_counter import FlopsCounter
from verl.utils.import_utils import import_external_libs
Expand Down Expand Up @@ -674,7 +674,9 @@ async def update_weights(self, global_steps: int = None):
# 0. send_weights only for async training with disaggregated trainer and rollout
if self.config.rollout.checkpoint_engine.backend != "naive":
per_tensor_param, _ = self.actor.engine.get_per_tensor_param()
await self.checkpoint_engine.send_weights(per_tensor_param)
per_tensor_param = list(per_tensor_param)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will materialize the weight generator and gather all sharded weights onto each GPU, causing CUDA OOM for large models.

get_torch_device().synchronize()
await self.checkpoint_engine.send_weights(iter(per_tensor_param))
return

set_expandable_segments(False)
Expand Down
25 changes: 21 additions & 4 deletions verl/workers/rollout/vllm_rollout/bucketed_weight_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,23 @@ async def async_send_weights(self, weights):

def _init_socket(self):
    """Create the ZMQ REQ socket and bind it to ``self.zmq_handle``.

    For IPC transports, any stale socket file left behind by a previous
    run is deleted first, since binding over a leftover file can fail or
    point peers at a dead endpoint.
    """
    handle = self.zmq_handle
    prefix = "ipc://"
    if handle.startswith(prefix):
        stale_path = handle[len(prefix):]
        try:
            os.remove(stale_path)
        except FileNotFoundError:
            pass  # no leftover file — nothing to clean up
    sock = self.zmq_context.socket(zmq.REQ)
    sock.bind(handle)
    self.socket = sock

def _init_buffer(self):
"""build communication buffer"""
"""build communication buffer, reuse if already allocated"""
if self.buffer is not None and not self.use_shm:
handle = reduce_tensor(self.buffer)
self.socket.send_pyobj(handle)
self.socket.recv()
return

buffer, shm = None, None
if not self.use_shm:
buffer = torch.empty(self.bucket_size, dtype=torch.uint8, device=f"{get_device_name()}:{get_device_id()}")
Expand All @@ -181,17 +193,22 @@ def _init_buffer(self):
self.shm = shm

def _cleanup(self):
"""clean up"""
"""clean up socket but keep buffer for reuse"""
if self.socket is not None:
self.socket.close()
self.socket = None
del self.buffer
self.buffer = None
if self.zmq_handle.startswith("ipc://"):
ipc_path = self.zmq_handle[len("ipc://") :]
try:
os.remove(ipc_path)
except FileNotFoundError:
pass
if self.shm is not None:
self.shm.close()
self.shm.unlink()
del self.shm
self.shm = None
self.buffer = None
gc.collect()
get_torch_device().ipc_collect()
get_torch_device().empty_cache()
Expand Down
8 changes: 2 additions & 6 deletions verl/workers/rollout/vllm_rollout/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,7 @@ def _update_weights(self, weights: list[tuple[str, torch.Tensor]], peft_config:

def _get_zmq_handle(self) -> str:
"""Get ZMQ handle for communication."""
if not hasattr(self, "device_uuid") or not self.device_uuid:
self.device_uuid = get_device_uuid(self.device.index)
return f"ipc:///tmp/rl-colocate-zmq-{self.device_uuid}.sock"
return f"ipc:///tmp/rl-colocate-zmq-rank-{self.local_rank}.sock"


class vLLMOmniColocateWorkerExtension(_OmniWorkerBase):
Expand Down Expand Up @@ -331,9 +329,7 @@ def _update_weights(self, weights: list[tuple[str, torch.Tensor]], peft_config:

def _get_zmq_handle(self) -> str:
"""Get ZMQ handle for communication."""
if not hasattr(self, "device_uuid") or not self.device_uuid:
self.device_uuid = get_device_uuid(self.device.index)
return f"ipc:///tmp/rl-colocate-zmq-{self.device_uuid}.sock"
return f"ipc:///tmp/rl-colocate-zmq-rank-{self.local_rank}.sock"


class SuppressSignalInThread:
Expand Down
17 changes: 9 additions & 8 deletions verl/workers/rollout/vllm_rollout/vllm_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(
self.sleep_level = VLLM_SLEEP_LEVEL

self.device_uuid = get_device_uuid(get_device_id())
self.zmq_handle = f"ipc:///tmp/rl-colocate-zmq-{self.device_uuid}.sock"
self.zmq_handle = f"ipc:///tmp/rl-colocate-zmq-rank-{rank % local_world_size}.sock"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will conflict for multiple vLLM replicas on the same node, e.g. 2 replicas with TP=4 located on the same node.


self.use_shm = not is_support_ipc()
if self.use_shm:
Expand Down Expand Up @@ -163,13 +163,14 @@ async def update_weights(
kwargs={**kwargs, "use_shm": self.use_shm},
)

bucket_size_mb = self.config.checkpoint_engine.update_weights_bucket_megabytes
sender = BucketedWeightSender(
zmq_handle=self.zmq_handle,
bucket_size_mb=bucket_size_mb,
use_shm=self.use_shm,
)
await sender.async_send_weights(weights)
if not hasattr(self, "_weight_sender") or self._weight_sender is None:
bucket_size_mb = self.config.checkpoint_engine.update_weights_bucket_megabytes
self._weight_sender = BucketedWeightSender(
zmq_handle=self.zmq_handle,
bucket_size_mb=bucket_size_mb,
use_shm=self.use_shm,
)
await self._weight_sender.async_send_weights(weights)

if future is not None:
await future
Expand Down