Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion verl/checkpoint_engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from verl.utils.ray_utils import auto_await
from verl.workers.config import CheckpointEngineConfig, HFModelConfig, RolloutConfig
from verl.workers.rollout import BaseRollout, RolloutReplica, get_rollout_class
from verl.workers.rollout.utils import get_minimum_bucket_size_mb


class TensorMeta(TypedDict):
Expand Down Expand Up @@ -265,7 +266,12 @@ def __init__(

self.server_adapter: BaseRollout = server_adapter
backend = self.rollout_config.checkpoint_engine.backend
bucket_size = self.rollout_config.checkpoint_engine.update_weights_bucket_megabytes << 20
# Auto-adjust bucket size based on embedding weight size
self.bucket_size_mb = get_minimum_bucket_size_mb(
hf_config=self.model_config.hf_config,
current_bucket_size_mb=self.config.checkpoint_engine.update_weights_bucket_megabytes,
)
bucket_size = self.bucket_size_mb << 20
engine_kwargs = self.rollout_config.checkpoint_engine.engine_kwargs.get(backend, {})
# If custom_backend_module is set, import it so plugins can register
# in CheckpointEngineRegistry before the backend is instantiated.
Expand Down
8 changes: 7 additions & 1 deletion verl/workers/engine_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
TrainingWorkerConfig,
)
from verl.workers.rollout.base import BaseRollout, get_rollout_class
from verl.workers.rollout.utils import get_minimum_bucket_size_mb
from verl.workers.utils.losses import ppo_loss

logger = logging.getLogger(__file__)
Expand Down Expand Up @@ -613,7 +614,12 @@ def init_model(self):
if "actor" in self.role:
checkpoint_engine_config = omega_conf_to_dataclass(self.config.rollout.checkpoint_engine)
backend = checkpoint_engine_config.backend
bucket_size = checkpoint_engine_config.update_weights_bucket_megabytes << 20
# Auto-adjust bucket size based on embedding weight size
bucket_size_mb = get_minimum_bucket_size_mb(
hf_config=model_config,
current_bucket_size_mb=checkpoint_engine_config.update_weights_bucket_megabytes,
)
bucket_size = bucket_size_mb << 20
engine_kwargs = checkpoint_engine_config.engine_kwargs.get(backend, {})
# If custom_backend_module is set, import it so plugins can register
# in CheckpointEngineRegistry before the backend is instantiated.
Expand Down
9 changes: 8 additions & 1 deletion verl/workers/rollout/sglang_rollout/sglang_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
SGLANG_LORA_NAME,
get_named_tensor_buckets,
)
from verl.workers.rollout.utils import get_minimum_bucket_size_mb

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
Expand Down Expand Up @@ -125,6 +126,12 @@ def __init__(
self.model_config.hf_config.quantization_config = fp8_block_quant_kwargs
self._engine: AsyncHttpServerAdapter = None

# Auto-adjust bucket size based on embedding weight size
self.bucket_size_mb = get_minimum_bucket_size_mb(
hf_config=self.model_config.hf_config,
current_bucket_size_mb=self.config.checkpoint_engine.update_weights_bucket_megabytes,
)

rank = int(os.environ["RANK"])
local_world_size = int(os.environ["RAY_LOCAL_WORLD_SIZE"])
rollout_world_size = self.config.tensor_model_parallel_size * self.config.data_parallel_size
Expand Down Expand Up @@ -240,7 +247,7 @@ async def update_weights(
# send http request
await self._engine.load_lora_adapter_from_tensor(req)
else:
update_weights_bucket_bytes = int(self.config.checkpoint_engine.update_weights_bucket_megabytes) << 20
update_weights_bucket_bytes = int(self.bucket_size_mb) << 20
if self.config.get("quantization", None) == "fp8":
from verl.utils.sglang.sglang_fp8_utils import SGLangFP8QuantizerHelper

Expand Down
9 changes: 7 additions & 2 deletions verl/workers/rollout/trtllm_rollout/trtllm_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from verl.utils.net_utils import is_valid_ipv6_address
from verl.workers.config import HFModelConfig, RolloutConfig
from verl.workers.rollout.base import BaseRollout
from verl.workers.rollout.utils import ensure_async_iterator
from verl.workers.rollout.utils import ensure_async_iterator, get_minimum_bucket_size_mb

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
Expand Down Expand Up @@ -301,6 +301,11 @@ def __init__(
fp8_block_quant_kwargs = dict(FP8_BLOCK_QUANT_KWARGS)
model_config.hf_config.quantization_config = fp8_block_quant_kwargs
super().__init__(config, model_config, device_mesh)
# Auto-adjust bucket size based on embedding weight size
self.bucket_size_mb = get_minimum_bucket_size_mb(
hf_config=self.model_config.hf_config,
current_bucket_size_mb=self.config.checkpoint_engine.update_weights_bucket_megabytes,
)
self._adapter = None
self.hybrid_device_mesh = None
self.gpu_id = None
Expand Down Expand Up @@ -440,7 +445,7 @@ async def update_weights(
if self.is_leader_rank:
await self._init_server_adapter()

total_available_bytes = int(self.config.checkpoint_engine.update_weights_bucket_megabytes) * 1024 * 1024
total_available_bytes = int(self.bucket_size_mb) * 1024 * 1024

if self.config.get("quantization", None) == "fp8":
from verl.utils.trtllm.trtllm_fp8_utils import TRTLLMFP8QuantizerHelper
Expand Down
47 changes: 47 additions & 0 deletions verl/workers/rollout/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import asyncio
import logging
import math

import numpy as np
import uvicorn
Expand All @@ -33,6 +34,52 @@ def get_max_position_embeddings(hf_config) -> int:
return int(max_len)


def get_minimum_bucket_size_mb(hf_config, current_bucket_size_mb: int) -> int:
    """
    Return a bucket size (MB) guaranteed to hold the model's embedding weight.

    The token-embedding matrix (``embed_tokens``) is usually the single largest
    weight tensor in a model; any weight-transfer bucket must be at least that
    big or the transfer asserts. If the configured bucket already fits the
    embedding, it is returned unchanged; otherwise a larger value (rounded up
    to a 512 MB multiple) is returned and a warning is logged.

    Multimodal models (e.g. Qwen3-VL) nest ``vocab_size``/``hidden_size``
    under ``text_config``, so that sub-config is preferred when present.

    Args:
        hf_config: HuggingFace model config object.
        current_bucket_size_mb: Currently configured bucket size in MB.

    Returns:
        A bucket size in MB that fits the embedding weight.
    """
    # Prefer the nested text_config (multimodal models); fall back to top level.
    cfg = getattr(hf_config, "text_config", None) or hf_config
    vocab_size = getattr(cfg, "vocab_size", 0)
    hidden_size = getattr(cfg, "hidden_size", 0)

    # Without both dimensions we cannot size the embedding — leave config as-is.
    if not vocab_size or not hidden_size:
        return current_bucket_size_mb

    # embed_tokens is [vocab_size, hidden_size]; assume float32 (4 bytes) as a
    # conservative upper bound on the serialized tensor size.
    bytes_per_elem = 4
    embed_size_mb = math.ceil(vocab_size * hidden_size * bytes_per_elem / 1024 / 1024)

    if embed_size_mb <= current_bucket_size_mb:
        # Configured bucket already fits the embedding weight.
        return current_bucket_size_mb

    # Bump to the next 512 MB boundary strictly above the embedding size.
    recommended_mb = (embed_size_mb // 512 + 1) * 512
    logger.warning(
        f"Embedding weight size ({embed_size_mb} MB) exceeds "
        f"update_weights_bucket_megabytes ({current_bucket_size_mb} MB), "
        f"automatically increasing to {recommended_mb} MB."
    )
    return recommended_mb


class _UvicornServerAutoPort(uvicorn.Server):
"""Uvicorn Server that reports the system-assigned port when port=0."""

Expand Down
9 changes: 7 additions & 2 deletions verl/workers/rollout/vllm_rollout/vllm_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from verl.utils.device import get_device_id, is_support_ipc
from verl.workers.config import HFModelConfig, RolloutConfig
from verl.workers.rollout.base import BaseRollout
from verl.workers.rollout.utils import get_minimum_bucket_size_mb
from verl.workers.rollout.vllm_rollout.bucketed_weight_transfer import BucketedWeightSender
from verl.workers.rollout.vllm_rollout.utils import get_device_uuid

Expand Down Expand Up @@ -73,6 +74,11 @@ def __init__(
):
super().__init__(config, model_config, device_mesh)
self.server_handle: ray.actor.ActorHandle = None
# Auto-adjust bucket size based on embedding weight size
self.bucket_size_mb = get_minimum_bucket_size_mb(
hf_config=self.model_config.hf_config,
current_bucket_size_mb=self.config.checkpoint_engine.update_weights_bucket_megabytes,
)

rank = int(os.environ["RANK"])
local_world_size = int(os.environ["RAY_LOCAL_WORLD_SIZE"])
Expand Down Expand Up @@ -163,10 +169,9 @@ async def update_weights(
kwargs={**kwargs, "use_shm": self.use_shm},
)

bucket_size_mb = self.config.checkpoint_engine.update_weights_bucket_megabytes
sender = BucketedWeightSender(
zmq_handle=self.zmq_handle,
bucket_size_mb=bucket_size_mb,
bucket_size_mb=self.bucket_size_mb,
use_shm=self.use_shm,
)
await sender.async_send_weights(weights)
Expand Down
Loading