5 changes: 4 additions & 1 deletion verl/checkpoint_engine/base.py
@@ -289,7 +289,10 @@ def __init__(
     @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
     async def update_weights(self, global_steps: int = None):
         weights = self.checkpoint_engine.receive_weights()
-        await self.server_adapter.update_weights(weights, global_steps=global_steps)
+        pre_quantized_fp8 = self.rollout_config.get("trainer_quantize_fp8", False)
+        await self.server_adapter.update_weights(
+            weights, global_steps=global_steps, pre_quantized_fp8=pre_quantized_fp8
+        )

     @register(dispatch_mode=Dispatch.DP_COMPUTE, blocking=False)
     def execute_checkpoint_engine(self, method: str, *args, **kwargs):
4 changes: 4 additions & 0 deletions verl/trainer/config/rollout/rollout.yaml
@@ -421,6 +421,10 @@ quantization: null
 # extra quantization information serialized in a config file, e.g. torchao_config.json
 quantization_config_file: null

+# When true, FP8 quantization is performed on the trainer side before weight sync,
+# halving transfer bandwidth in disaggregated mode. Only effective when quantization=fp8.
+trainer_quantize_fp8: false
+
 # MTP configuration, reuse model configuration
 mtp: ${oc.select:actor_rollout_ref.model.mtp, null}
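A back-of-envelope check of the "halving transfer bandwidth" comment, assuming bf16 source weights, 1-byte FP8 storage, and one fp32 scale_inv per [128, 128] block (the block size hardcoded in the engine_workers change below); the model size is illustrative:

```python
# Transfer-size estimate for trainer-side FP8 quantization (illustrative).
params = 7e9                               # e.g. a 7B-parameter model
bf16_bytes = params * 2                    # 2 bytes per bf16 weight
fp8_bytes = params * 1                     # 1 byte per fp8 weight
scale_bytes = params / (128 * 128) * 4     # one fp32 scale_inv per block
print((fp8_bytes + scale_bytes) / bf16_bytes)  # ~0.5001, i.e. roughly half
```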
35 changes: 35 additions & 0 deletions verl/utils/fp8_utils.py
@@ -130,3 +130,38 @@ async def quant_weights_by_name(self, weights, dtype=torch.bfloat16):
                 logger.error(f"Failed to quantize {k}: {e}")
                 # If quantization fails, use original weights
                 yield (k, v)
+
+    def quant_weights_by_name_sync(self, weights, dtype=torch.bfloat16):
Collaborator: We can use ensure_async_iterator in checkpoint engines to async-for iterate the weights, and eliminate this sync version.
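A minimal sketch of what that suggestion could look like; the helper name comes from the review comment, and this body is an assumption, not verl's actual implementation:

```python
async def ensure_async_iterator(it):
    """Adapt a sync or async iterable so callers can always `async for` it."""
    if hasattr(it, "__aiter__"):   # already async (e.g. quant_weights_by_name)
        async for item in it:
            yield item
    else:                          # plain generator from a checkpoint engine
        for item in it:
            yield item
```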

"""Synchronous version for checkpoint engine send_weights path.

Checkpoint engines (NCCL/NIXL/HCCL) iterate weights with a standard
for loop, so the async version cannot be used.
"""
if isinstance(self.quant_config, dict):
weight_block_size = self.quant_config.get("weight_block_size")
else:
weight_block_size = getattr(self.quant_config, "weight_block_size", None)

if weight_block_size is None:
raise ValueError("weight_block_size not found in quant_config")

for k, v in weights:
if not self.should_quantize_param(k):
yield (k, v)
continue

try:
param_lp, param_scale = scaled_fp8_blockwise(
v.to(dtype),
weight_block_size=weight_block_size,
)
param_scale = param_scale.squeeze(-1)

yield (k, param_lp)
yield (k + "_scale_inv", param_scale)

del param_lp, param_scale

except Exception as e:
logger.error(f"Failed to quantize {k}: {e}")
yield (k, v)
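A hedged usage sketch of the sync generator above, driven the way a checkpoint engine's plain for loop would drive it; whether this particular parameter name passes should_quantize_param is an assumption:

```python
import torch

from verl.utils.fp8_utils import FP8QuantizerHelper

quantizer = FP8QuantizerHelper({"weight_block_size": [128, 128]})
weights = [("model.layers.0.mlp.gate_proj.weight", torch.randn(512, 512))]

for name, tensor in quantizer.quant_weights_by_name_sync(iter(weights), dtype=torch.bfloat16):
    # Each quantized weight is followed by a companion "<name>_scale_inv" entry.
    print(name, tensor.dtype, tuple(tensor.shape))
```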
2 changes: 2 additions & 0 deletions verl/workers/config/rollout.py
@@ -279,6 +279,8 @@ class RolloutConfig(BaseConfig):

     quantization_config_file: Optional[str] = None

+    trainer_quantize_fp8: bool = False
+
     enable_rollout_routing_replay: bool = False

     enable_sleep_mode: bool = True
32 changes: 30 additions & 2 deletions verl/workers/engine_workers.py
@@ -674,6 +674,13 @@ async def update_weights(self, global_steps: int = None):
         # 0. send_weights only for async training with disaggregated trainer and rollout
         if self.config.rollout.checkpoint_engine.backend != "naive":
             per_tensor_param, _ = self.actor.engine.get_per_tensor_param()
+            if self.config.rollout.get("trainer_quantize_fp8", False):
+                from verl.utils.fp8_utils import FP8QuantizerHelper
+
+                quant_config = {"weight_block_size": [128, 128]}
+                fp8_quantizer = FP8QuantizerHelper(quant_config)
+                per_tensor_param = fp8_quantizer.quant_weights_by_name_sync(per_tensor_param, dtype=torch.bfloat16)
+                logger.info("FP8 trainer-side quantization enabled (disaggregated)")
             await self.checkpoint_engine.send_weights(per_tensor_param)
             return
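The sync variant is required on this path because checkpoint engines drain the iterator with a plain for loop; a stand-in sketch of that consumption pattern (not verl's actual send_weights):

```python
def send_weights(weight_iter):
    """Stand-in consumer: drains (name, tensor) pairs with a plain `for`."""
    total_bytes = 0
    for _name, tensor in weight_iter:  # `for`, not `async for`
        total_bytes += tensor.numel() * tensor.element_size()
    return total_bytes
```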

@@ -690,6 +697,17 @@
                 layered_summon=self.layered_summon, base_sync_done=True
             )

+        # FP8: quantize weights on trainer side before sending to rollout
+        _pre_quantized_fp8 = False
+        fp8_quantizer = None
+        if self.config.rollout.get("trainer_quantize_fp8", False) and (peft_config is None or self.peft_merge):
+            from verl.utils.fp8_utils import FP8QuantizerHelper
+
+            fp8_quantizer = FP8QuantizerHelper(self.rollout.model_config.hf_config.quantization_config)
+            per_tensor_param = fp8_quantizer.quant_weights_by_name(per_tensor_param, dtype=torch.bfloat16)
+            _pre_quantized_fp8 = True
+            logger.info("FP8 trainer-side quantization enabled")
+
         do_lora_base_sync = False
         if not self.peft_merge and peft_config is not None:
             self.rollout.sleep_level = 1
@@ -700,12 +718,22 @@
             per_tensor_param_base, peft_config = self.actor.engine.get_per_tensor_param(
                 layered_summon=self.layered_summon, base_sync_done=False
             )
+            if _pre_quantized_fp8:
+                per_tensor_param_base = fp8_quantizer.quant_weights_by_name(per_tensor_param_base, dtype=torch.bfloat16)
             await self.rollout.update_weights(
-                per_tensor_param_base, peft_config=peft_config, base_sync_done=False, global_steps=global_steps
+                per_tensor_param_base,
+                peft_config=peft_config,
+                base_sync_done=False,
+                global_steps=global_steps,
+                pre_quantized_fp8=_pre_quantized_fp8,
             )

         await self.rollout.update_weights(
-            per_tensor_param, peft_config=peft_config, base_sync_done=True, global_steps=global_steps
+            per_tensor_param,
+            peft_config=peft_config,
+            base_sync_done=True,
+            global_steps=global_steps,
+            pre_quantized_fp8=_pre_quantized_fp8,
         )

         log_gpu_memory_usage("After update_weights", logger=logger)
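Note that quant_weights_by_name is applied separately to per_tensor_param_base and per_tensor_param: it returns a one-shot generator, so each weight stream needs its own wrapper, as this toy example shows:

```python
def gen(xs):
    yield from xs

g = gen([1, 2])
assert list(g) == [1, 2]
assert list(g) == []  # exhausted; a second stream requires a second call
```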
26 changes: 24 additions & 2 deletions verl/workers/fsdp_workers.py
@@ -856,6 +856,17 @@ async def rollout_mode(self):
         )
         aggressive_empty_cache(force_sync=True)

+        # FP8: quantize weights on trainer side before sending to rollout
+        _pre_quantized_fp8 = False
+        fp8_quantizer = None
+        if self.config.rollout.get("trainer_quantize_fp8", False) and (peft_config is None or self.peft_merge):
+            from verl.utils.fp8_utils import FP8QuantizerHelper
+
+            fp8_quantizer = FP8QuantizerHelper(self.rollout.model_config.hf_config.quantization_config)
+            per_tensor_param = fp8_quantizer.quant_weights_by_name(per_tensor_param, dtype=self._param_dtype)
+            _pre_quantized_fp8 = True
+            logger.info("FP8 trainer-side quantization enabled")
+
         if self.config.rollout.free_cache_engine:
             await self.rollout.resume(tags=["weights"])
             log_gpu_memory_usage("After resume weights", logger=logger)
@@ -870,10 +881,21 @@
                 (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param)
                 for name, param in base_model_params.items()
             )
-            await self.rollout.update_weights(per_tensor_base_params, base_sync_done=False)
+            if _pre_quantized_fp8:
+                per_tensor_base_params = fp8_quantizer.quant_weights_by_name(
+                    per_tensor_base_params, dtype=self._param_dtype
+                )
+            await self.rollout.update_weights(
+                per_tensor_base_params, base_sync_done=False, pre_quantized_fp8=_pre_quantized_fp8
+            )
             del base_model_params, per_tensor_base_params

         await self.rollout.update_weights(
-            per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done)
+            per_tensor_param,
+            peft_config=peft_config,
+            base_sync_done=self.base_sync_done,
+            pre_quantized_fp8=_pre_quantized_fp8,
+        )
         log_gpu_memory_usage("After update_weights", logger=logger)
         del params, per_tensor_param
         aggressive_empty_cache(force_sync=True)
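The base-param generator above gathers each DTensor shard into a full tensor before quantization, presumably so blockwise scales are computed over the complete weight rather than a shard; a sketch of that materialization step (import path assumes torch >= 2.4):

```python
import torch
from torch.distributed.tensor import DTensor  # older torch: torch.distributed._tensor

def materialize(param: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Gather a sharded DTensor into a regular tensor on `device`."""
    if isinstance(param, DTensor):
        return param.to(device, non_blocking=True).full_tensor()
    return param
```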
6 changes: 3 additions & 3 deletions verl/workers/rollout/sglang_rollout/sglang_rollout.py
@@ -241,7 +241,7 @@ async def update_weights(
             await self._engine.load_lora_adapter_from_tensor(req)
         else:
             update_weights_bucket_bytes = int(self.config.checkpoint_engine.update_weights_bucket_megabytes) << 20
-            if self.config.get("quantization", None) == "fp8":
+            if self.config.get("quantization", None) == "fp8" and not kwargs.get("pre_quantized_fp8", False):
                 from verl.utils.sglang.sglang_fp8_utils import SGLangFP8QuantizerHelper

                 logger.info("Convert bf16 weights to fp8 format before loading")
@@ -250,8 +250,8 @@
                     weights,
                     dtype=self.model_config.hf_config.dtype,
                 )
-            else:
-                weights = weights
+            elif kwargs.get("pre_quantized_fp8", False):
+                logger.info("Skipping FP8 quantization, weights pre-quantized on trainer side")

             async for params_batch in get_named_tensor_buckets(weights, update_weights_bucket_bytes):
                 await sgl_update_weights(
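The receive-side rule after this change reduces to a single predicate; a hedged restatement of the branch logic above:

```python
def needs_rollout_side_fp8(quantization, pre_quantized_fp8):
    """Quantize on the rollout side only when fp8 is requested and the
    trainer has not already produced fp8 weights plus scale_inv tensors."""
    return quantization == "fp8" and not pre_quantized_fp8
```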