From 233b7a5316ee8d475e56716e6723c096e5466d54 Mon Sep 17 00:00:00 2001 From: Alex Piche Date: Sat, 27 Sep 2025 03:33:02 +0000 Subject: [PATCH 01/35] tweak task template --- pipelinerl/domains/math/rollouts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipelinerl/domains/math/rollouts.py b/pipelinerl/domains/math/rollouts.py index 41a61021..171cd945 100644 --- a/pipelinerl/domains/math/rollouts.py +++ b/pipelinerl/domains/math/rollouts.py @@ -43,7 +43,7 @@ async def generate_math_rollout( messages = [] if cfg.actor.system_prompt: messages.append({"role": "system", "content": cfg.actor.system_prompt}) - messages.append({"role": "user", "content": cfg.actor.task_template.format(task=problem["task"])}) + messages.append({"role": "user", "content": f"{problem['task']}\n{cfg.actor.task_prompt}"}) prompt = Prompt(messages=messages) time_start = time.time() From 6adb0e8c446488abbd23a8176674cb56f03e8616 Mon Sep 17 00:00:00 2001 From: Alex Piche Date: Sat, 27 Sep 2025 03:34:02 +0000 Subject: [PATCH 02/35] tweak task template --- pipelinerl/domains/math/rollouts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipelinerl/domains/math/rollouts.py b/pipelinerl/domains/math/rollouts.py index 171cd945..8b88d877 100644 --- a/pipelinerl/domains/math/rollouts.py +++ b/pipelinerl/domains/math/rollouts.py @@ -43,7 +43,7 @@ async def generate_math_rollout( messages = [] if cfg.actor.system_prompt: messages.append({"role": "system", "content": cfg.actor.system_prompt}) - messages.append({"role": "user", "content": f"{problem['task']}\n{cfg.actor.task_prompt}"}) + messages.append({"role": "user", "content": f"{problem['task']} \n{cfg.actor.task_prompt}"}) prompt = Prompt(messages=messages) time_start = time.time() From ef157929759adea9fd95d3f9cbbaf87cf09c1283 Mon Sep 17 00:00:00 2001 From: Alex Piche Date: Sat, 27 Sep 2025 23:38:47 +0000 Subject: [PATCH 03/35] gspo --- pipelinerl/finetune/rl/__init__.py | 57 ++++++++++++++++++++++-------- 1 file changed, 43 insertions(+), 14 deletions(-) diff --git a/pipelinerl/finetune/rl/__init__.py b/pipelinerl/finetune/rl/__init__.py index e74b9a0b..f35a9549 100644 --- a/pipelinerl/finetune/rl/__init__.py +++ b/pipelinerl/finetune/rl/__init__.py @@ -39,7 +39,7 @@ class RLConfig(BaseModel): policy_loss: str = Field( default="ppo", description="Policy Loss to use for RL", - choices=["ppo", "reinforce"], + choices=["ppo", "reinforce", "gspo"], ) use_advantages: bool = Field( default=True, @@ -280,17 +280,46 @@ def rl_step( clamp_log_ratio_new_old_indicators = ratio_new_old > 1 + config.epsilon ratio_new_old = torch.clamp(ratio_new_old, 0, 1 + config.epsilon) policy_loss = new_logprobs * log_p_weights * ratio_new_old.detach() + case "gspo": + if segments is None: + raise ValueError("GSPO loss requires packed sequences with segments") + # Aggregate per-sequence means over valid (labeled) tokens only. + # Skip sequences with zero labeled tokens to avoid NaNs from mean([]). 
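+ # GSPO uses a sequence-level importance ratio: for each packed sequence i,
+ #   s_i = exp(mean_t log(pi_new(t) / pi_old(t))),
+ # i.e. the length-normalized product of token ratios, which is then clipped
+ # and weighted by a per-sequence advantage (see the exp/clamp logic below).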
+ group_log_ratio_new_old: list[torch.Tensor] = [] + group_advantages: list[torch.Tensor] = [] + for i in range(num_sequences): + start, end = segments[i] + group_mask_id = masks_shifted[0, start:end] + if not torch.any(group_mask_id): + continue # no output tokens in this segment; skip it entirely + grp_lrn = log_ratio_new_old[0, start:end][group_mask_id] + grp_adv = advantages[0, start:end][group_mask_id] + # means are well-defined because we ensured at least one element + group_log_ratio_new_old.append(grp_lrn.mean()) + group_advantages.append(grp_adv.mean()) + if len(group_log_ratio_new_old) == 0: + # No valid groups in this batch; define zero loss on current step + policy_loss_total = torch.zeros((), device=new_logprobs.device, dtype=new_logprobs.dtype) + else: + group_ratio_new_old = torch.exp(torch.stack(group_log_ratio_new_old)).unsqueeze(1).unsqueeze(2) + group_advantages_t = torch.stack(group_advantages).unsqueeze(1).unsqueeze(2) + surr1 = group_ratio_new_old * group_advantages_t + clamped_group_ratio = torch.clamp(group_ratio_new_old, 1 - config.epsilon, 1 + config.epsilon) + clamp_log_ratio_new_old_indicators = clamped_group_ratio != group_ratio_new_old + surr2 = clamped_group_ratio * group_advantages_t + policy_loss_total = torch.min(surr1, surr2).sum() case _: raise ValueError(f"Unknown algorithm {config.policy_loss}") # combine loss components - loss = policy_loss - kl_coef * approx_kl + entropy_bonus_coef * entropy # 1 x (BxL) x 1 - assert loss.shape == tokens_weights.shape, ( - f"Loss shape {loss.shape} does not match example weights shape {tokens_weights.shape}" - ) - loss = loss * tokens_weights # 1 x (BxL) x 1 + if config.policy_loss != "gspo": + loss = policy_loss - kl_coef * approx_kl + entropy_bonus_coef * entropy # 1 x (BxL) x 1 + assert loss.shape == tokens_weights.shape, ( + f"Loss shape {loss.shape} does not match example weights shape {tokens_weights.shape}" + ) + loss = loss * tokens_weights # 1 x (BxL) x 1 - policy_loss_total = -sum_sum(loss, masks_shifted, segments) + policy_loss_total = -sum_sum(loss, masks_shifted, segments) if has_value_head: # Get the value predictions @@ -341,9 +370,9 @@ def rl_step( "kl_new_old": sum_sum(approx_kl_new_old / num_labels_in_seq, masks_shifted, segments).item(), "max_kl": approx_kl[masks_shifted].max().item(), "min_kl": approx_kl[masks_shifted].min().item(), - "policy_loss": sum_sum(policy_loss / num_labels_in_seq, masks_shifted, segments).item(), - "surr1": sum_sum(surr1 / num_labels_in_seq, masks_shifted, segments).item(), - "surr2": sum_sum(surr2 / num_labels_in_seq, masks_shifted, segments).item(), + #"policy_loss": sum_sum(policy_loss / num_labels_in_seq, masks_shifted, segments).item(), + #"surr1": sum_sum(surr1 / num_labels_in_seq, masks_shifted, segments).item(), + #"surr2": sum_sum(surr2 / num_labels_in_seq, masks_shifted, segments).item(), "ratio_new_old": sum_sum(ratio_new_old / num_labels_in_seq, masks_shifted, segments).item(), "ratio_new_old_sum": sum_sum(ratio_new_old, masks_shifted, segments).item(), "ratio_new_old_squared_sum": sum_sum( # useful to estimate the ESS @@ -354,10 +383,10 @@ def rl_step( "clamp_log_ratio_ref_new_indicator": sum_sum( clamp_log_ratio_ref_new_indicators / num_labels_in_seq, masks_shifted, segments ).item(), - "clamp_log_ratio_new_old_indicator": sum_sum( - clamp_log_ratio_new_old_indicators / num_labels_in_seq, masks_shifted, segments - ).item(), - "num_nans": torch.isnan(loss).sum().item(), + #"clamp_log_ratio_new_old_indicator": sum_sum( + # clamp_log_ratio_new_old_indicators / 
num_labels_in_seq, masks_shifted, segments + #).item(), + #"num_nans": torch.isnan(loss).sum().item(), "token_weight": sum_sum(tokens_weights / num_labels_in_seq, masks_shifted, segments).item(), "max_token_weight": tokens_weights[masks_shifted].max().item(), "min_token_weight": tokens_weights[masks_shifted].min().item(), From 2b97222cccb29d102d9386566087d4e01aad8ed2 Mon Sep 17 00:00:00 2001 From: Alex Piche Date: Sun, 28 Sep 2025 00:28:49 +0000 Subject: [PATCH 04/35] sentinel batch --- pipelinerl/finetune/rl/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipelinerl/finetune/rl/__init__.py b/pipelinerl/finetune/rl/__init__.py index f35a9549..d2f29f9b 100644 --- a/pipelinerl/finetune/rl/__init__.py +++ b/pipelinerl/finetune/rl/__init__.py @@ -299,7 +299,7 @@ def rl_step( group_advantages.append(grp_adv.mean()) if len(group_log_ratio_new_old) == 0: # No valid groups in this batch; define zero loss on current step - policy_loss_total = torch.zeros((), device=new_logprobs.device, dtype=new_logprobs.dtype) + policy_loss_total = 0 * ratio_new_old.sum() else: group_ratio_new_old = torch.exp(torch.stack(group_log_ratio_new_old)).unsqueeze(1).unsqueeze(2) group_advantages_t = torch.stack(group_advantages).unsqueeze(1).unsqueeze(2) From 9e31adf0b44882c5c296b7a083de7c1e82353394 Mon Sep 17 00:00:00 2001 From: Alex Piche Date: Sun, 28 Sep 2025 00:49:17 +0000 Subject: [PATCH 05/35] typo --- conf/math.yaml | 3 +-- pipelinerl/finetune/rl/__init__.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/conf/math.yaml b/conf/math.yaml index 069aa96b..53dd2351 100644 --- a/conf/math.yaml +++ b/conf/math.yaml @@ -5,8 +5,7 @@ defaults: actor: rollout_policy: pipelinerl.domains.math.generate_math_rollout system_prompt: Please reason step by step, and put your final answer within \boxed{}. 
- task_template: |- - {task} + task_prompt: "" environment: _target_: pipelinerl.domains.math.MathEnvironment dataset_loader: pipelinerl.domains.math.load_datasets diff --git a/pipelinerl/finetune/rl/__init__.py b/pipelinerl/finetune/rl/__init__.py index d2f29f9b..16ce0db3 100644 --- a/pipelinerl/finetune/rl/__init__.py +++ b/pipelinerl/finetune/rl/__init__.py @@ -307,7 +307,7 @@ def rl_step( clamped_group_ratio = torch.clamp(group_ratio_new_old, 1 - config.epsilon, 1 + config.epsilon) clamp_log_ratio_new_old_indicators = clamped_group_ratio != group_ratio_new_old surr2 = clamped_group_ratio * group_advantages_t - policy_loss_total = torch.min(surr1, surr2).sum() + policy_loss_total = -torch.min(surr1, surr2).sum() case _: raise ValueError(f"Unknown algorithm {config.policy_loss}") From e7e34e1822ea6217b052c81195b27e85b5265075 Mon Sep 17 00:00:00 2001 From: Alex Piche Date: Sun, 28 Sep 2025 03:21:27 +0000 Subject: [PATCH 06/35] seq parallel --- pipelinerl/finetune/data.py | 4 ++ pipelinerl/finetune/rl/__init__.py | 39 +++++------ pipelinerl/finetune/rl/utils.py | 102 ++++++++++++++++++++++++++++- pipelinerl/finetune/types.py | 6 +- pipelinerl/finetune/utils.py | 2 + 5 files changed, 127 insertions(+), 26 deletions(-) diff --git a/pipelinerl/finetune/data.py b/pipelinerl/finetune/data.py index 9065b9e8..066f328f 100644 --- a/pipelinerl/finetune/data.py +++ b/pipelinerl/finetune/data.py @@ -239,6 +239,7 @@ def collate_packed( "labels": torch.empty(1, total_length, dtype=torch.long), "attention_mask": torch.ones(1, total_length, dtype=torch.long), # initialize to 1s "position_ids": torch.empty(1, total_length, dtype=torch.long), + "segment_ids": torch.empty(1, total_length, dtype=torch.long), } # initialize lists for extra keys @@ -254,6 +255,9 @@ def collate_packed( # use arange to fill position_ids base_tensors["position_ids"][0, start_idx:end_idx] = torch.arange(seq_len) + # set per-token segment ids to this sequence index + if seq_len > 0: + base_tensors["segment_ids"][0, start_idx:end_idx] = i # process labels example_labels = torch.tensor(example["labels"], dtype=torch.long) diff --git a/pipelinerl/finetune/rl/__init__.py b/pipelinerl/finetune/rl/__init__.py index 16ce0db3..c2438f9a 100644 --- a/pipelinerl/finetune/rl/__init__.py +++ b/pipelinerl/finetune/rl/__init__.py @@ -11,6 +11,7 @@ from datasets import Dataset from transformers import PreTrainedModel from pipelinerl.finetune.types import PipelineBatchEncoding +from pipelinerl.finetune.rl.utils import per_segment_sums from .utils import ( sum_sum, @@ -133,6 +134,7 @@ def rl_step( current_step: int, max_step: int, config: RLConfig, + seq_parallel_group=None, ) -> tuple[torch.Tensor, dict[str, float]]: """ Perform a single RL step on the model using the given batch and config. @@ -285,29 +287,20 @@ def rl_step( raise ValueError("GSPO loss requires packed sequences with segments") # Aggregate per-sequence means over valid (labeled) tokens only. # Skip sequences with zero labeled tokens to avoid NaNs from mean([]). 
- group_log_ratio_new_old: list[torch.Tensor] = [] - group_advantages: list[torch.Tensor] = [] - for i in range(num_sequences): - start, end = segments[i] - group_mask_id = masks_shifted[0, start:end] - if not torch.any(group_mask_id): - continue # no output tokens in this segment; skip it entirely - grp_lrn = log_ratio_new_old[0, start:end][group_mask_id] - grp_adv = advantages[0, start:end][group_mask_id] - # means are well-defined because we ensured at least one element - group_log_ratio_new_old.append(grp_lrn.mean()) - group_advantages.append(grp_adv.mean()) - if len(group_log_ratio_new_old) == 0: - # No valid groups in this batch; define zero loss on current step - policy_loss_total = 0 * ratio_new_old.sum() - else: - group_ratio_new_old = torch.exp(torch.stack(group_log_ratio_new_old)).unsqueeze(1).unsqueeze(2) - group_advantages_t = torch.stack(group_advantages).unsqueeze(1).unsqueeze(2) - surr1 = group_ratio_new_old * group_advantages_t - clamped_group_ratio = torch.clamp(group_ratio_new_old, 1 - config.epsilon, 1 + config.epsilon) - clamp_log_ratio_new_old_indicators = clamped_group_ratio != group_ratio_new_old - surr2 = clamped_group_ratio * group_advantages_t - policy_loss_total = -torch.min(surr1, surr2).sum() + lrn_sum, adv_sum, tok_count = per_segment_sums( + batch.segment_ids, + masks_shifted, + log_ratio_new_old, + advantages, + seq_parallel_group=seq_parallel_group, + ) + group_ratio_new_old = torch.exp(lrn_sum / tok_count.clamp(min=1e-6)).unsqueeze(1).unsqueeze(2) + group_advantages_t = (adv_sum / tok_count.clamp(min=1e-6)).unsqueeze(1).unsqueeze(2) + surr1 = group_ratio_new_old * group_advantages_t + clamped_group_ratio = torch.clamp(group_ratio_new_old, 1 - config.epsilon, 1 + config.epsilon) + clamp_log_ratio_new_old_indicators = clamped_group_ratio != group_ratio_new_old + surr2 = clamped_group_ratio * group_advantages_t + policy_loss_total = -torch.min(surr1, surr2).sum() case _: raise ValueError(f"Unknown algorithm {config.policy_loss}") diff --git a/pipelinerl/finetune/rl/utils.py b/pipelinerl/finetune/rl/utils.py index 9e4b49b3..3cc3265f 100644 --- a/pipelinerl/finetune/rl/utils.py +++ b/pipelinerl/finetune/rl/utils.py @@ -3,7 +3,7 @@ import numpy as np import torch from datasets import Dataset - +from torch import distributed as dist def aggregate_rl_stats(rl_stats: dict, num_samples: int): avg_rl_stats: dict[str, float] = {} @@ -96,3 +96,103 @@ def replace_dataset_column(dataset: Dataset, column_name: str, new_column: List[ dataset = dataset.add_column(name=column_name, column=new_column) # type: ignore return dataset + + + +def per_segment_sums( + segment_ids: torch.LongTensor, + masks_shifted: torch.Tensor, + log_ratio_new_old: torch.Tensor, + advantages: torch.Tensor, + seq_parallel_group=None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Differentiable version of per-segment sums using scatter_add instead of bincount. + + Args: + segment_ids: [1, L] integer segment id per token (packed batches). Should align with labels positions. + masks_shifted: [1, L-1] boolean mask for valid (non -100) labels excluding first token. + log_ratio_new_old: [1, L-1] tensor. + advantages: [1, L-1] tensor. + n_segments: total number of segments. + seq_parallel_group: optional torch.distributed group for seq-parallel; if provided, results are all-reduced. + + Returns: + (log_ratio_sum_per_seg, advantages_sum_per_seg, token_count_per_seg), each shaped [n_segments]. 
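+
+ Example (illustrative):
+ seg = torch.tensor([[0, 0, 1, 1]])  # two packed sequences
+ mask = torch.tensor([[True, True, True]])  # aligned with seg[:, 1:]
+ lrn = torch.tensor([[0.0, 0.1, 0.3]])
+ adv = torch.tensor([[1.0, 1.0, -1.0]])
+ # -> lrn sums [0.0, 0.4], adv sums [1.0, 0.0], token counts [1.0, 2.0]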
+ """ + if segment_ids is None: + raise ValueError("segment_ids must be provided for per-segment reductions") + + # Expect [1, L] -> align to shifted tensors [1, L-1] + if segment_ids.dim() != 2 or segment_ids.shape[0] != 1: + raise ValueError(f"Expected segment_ids of shape [1, L], got {tuple(segment_ids.shape)}") + + # Slice and unify device/dtypes + seg = segment_ids[:, 1:].contiguous().squeeze(0).to(dtype=torch.long) + local_max = seg.max().to(torch.int64) + + if seq_parallel_group is None or not dist.is_available() or not dist.is_initialized(): + n_segments = int(local_max.item()) + 1 + else: + global_max = local_max.clone() + dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=seq_parallel_group) + n_segments = int(global_max.item()) + 1 + + mask = masks_shifted[:, :seg.numel()].contiguous().squeeze(0) + lrn = log_ratio_new_old[:, :seg.numel()].contiguous().squeeze(0) + adv = advantages[:, :seg.numel()].contiguous().squeeze(0) + + # Put everything on same device + device = lrn.device + seg = seg.to(device=device) + mask = mask.to(device=device, dtype=lrn.dtype) # float mask is fine for weighting + adv = adv.to(device=device, dtype=lrn.dtype) + # lrn already on device + + # Consider only VALID tokens before indexing + # Important: this prevents out-of-bounds reads from indices you intended to ignore + valid = (mask != 0) + if valid.ndim != 1 or valid.shape[0] != seg.shape[0]: + raise ValueError("Mask shape mismatch after alignment with segment_ids.") + + if valid.any(): + seg_v = seg[valid] # indices actually used + w_v = mask[valid] # weights for counts + lrn_v = lrn[valid] + adv_v = adv[valid] + + # Range check BEFORE scatter to produce a clean error + smin = int(seg_v.min()) + smax = int(seg_v.max()) + if smin < 0 or smax >= n_segments: + raise IndexError( + f"per_segment_sums_diff_safe: segment index out of bounds. " + f"min(seg)={smin}, max(seg)={smax}, n_segments={n_segments}. " + "Likely causes: (1) n_segments too small (compute after packing), " + "(2) off-by-one when dropping the first token, " + "(3) segment ids are global across workers but you passed local n_segments." 
+ ) + + # Allocate outputs + token_count = torch.zeros(n_segments, dtype=lrn.dtype, device=device) + lrn_sum = torch.zeros(n_segments, dtype=lrn.dtype, device=device) + adv_sum = torch.zeros(n_segments, dtype=lrn.dtype, device=device) + + # index_add_ is equivalent to scatter_add_ for 1D reductions and can be clearer + token_count.index_add_(0, seg_v, w_v) + lrn_sum.index_add_(0, seg_v, lrn_v * w_v) + adv_sum.index_add_(0, seg_v, adv_v * w_v) + else: + # No valid tokens: return zeros + token_count = torch.zeros(n_segments, dtype=lrn.dtype, device=device) + lrn_sum = torch.zeros(n_segments, dtype=lrn.dtype, device=device) + adv_sum = torch.zeros(n_segments, dtype=lrn.dtype, device=device) + + # Optional all-reduce across sequence-parallel group + if seq_parallel_group is not None and dist.is_available() and dist.is_initialized(): + from torch.distributed.nn.functional import all_reduce + token_count = all_reduce(token_count, op=dist.ReduceOp.SUM, group=seq_parallel_group) + lrn_sum = all_reduce(lrn_sum, op=dist.ReduceOp.SUM, group=seq_parallel_group) + adv_sum = all_reduce(adv_sum, op=dist.ReduceOp.SUM, group=seq_parallel_group) + + return lrn_sum, adv_sum, token_count diff --git a/pipelinerl/finetune/types.py b/pipelinerl/finetune/types.py index 33194c90..a3c16f2e 100644 --- a/pipelinerl/finetune/types.py +++ b/pipelinerl/finetune/types.py @@ -53,6 +53,8 @@ class PipelineBatchEncoding(BaseModel): attention_mask: torch.LongTensor labels: torch.LongTensor position_ids: torch.LongTensor | None = None # Required when seq_packing=True + # Unique per-token segment identifier (e.g., original sequence id when packed) + segment_ids: torch.LongTensor | None = None rewards: torch.FloatTensor advantages: torch.FloatTensor @@ -72,7 +74,7 @@ class PipelineBatchEncoding(BaseModel): pixel_values: torch.FloatTensor | None = None image_grid_thw: torch.LongTensor | None = None - @field_validator('input_ids', 'attention_mask', 'labels', 'position_ids', 'image_grid_thw', mode='before') + @field_validator('input_ids', 'attention_mask', 'labels', 'position_ids', 'image_grid_thw', 'segment_ids', mode='before') @classmethod def convert_to_long_tensor(cls, v: List[int] | torch.Tensor | None) -> torch.LongTensor | None: """Handle initialization of long tensors from different types.""" @@ -157,6 +159,7 @@ def make_slices(self, num_slices: int) -> list['PipelineBatchEncoding']: "attention_mask": self.attention_mask[:, bs[i]:bs[i + 1]], "labels": self.labels[:, bs[i]:bs[i + 1]], "position_ids": self.position_ids[:, bs[i]:bs[i + 1]] if self.position_ids is not None else None, + "segment_ids": self.segment_ids[:, bs[i]:bs[i + 1]] if self.segment_ids is not None else None, "rewards": self.rewards[:, bs[i]:bs[i + 1]], "advantages": self.advantages[:, bs[i]:bs[i + 1]], "ref_logprobs": self.ref_logprobs[:, bs[i]:bs[i + 1]], @@ -176,4 +179,3 @@ def make_slices(self, num_slices: int) -> list['PipelineBatchEncoding']: slices.append(PipelineBatchEncoding(**result)) return slices - diff --git a/pipelinerl/finetune/utils.py b/pipelinerl/finetune/utils.py index ae10b3df..4095f842 100644 --- a/pipelinerl/finetune/utils.py +++ b/pipelinerl/finetune/utils.py @@ -28,6 +28,7 @@ def create_sentinel_batch(device, tokenizer=None, model_version=0) -> PipelineBa labels = [-100] * length attention_mask = [1] * length position_ids = list(range(length)) + segment_ids = [0] * length # single segment id for the whole sentinel batch # Prepare fields for dummy values (only needed for reward, advantages, etc.) 
zeros = [0.0] * length @@ -38,6 +39,7 @@ def create_sentinel_batch(device, tokenizer=None, model_version=0) -> PipelineBa "attention_mask": torch.tensor(attention_mask, dtype=torch.long).reshape(1, -1), "labels": torch.tensor(labels, dtype=torch.long).reshape(1, -1), "position_ids": torch.tensor(position_ids, dtype=torch.long).reshape(1, -1), + "segment_ids": torch.tensor(segment_ids, dtype=torch.long).reshape(1, -1), "rewards": torch.tensor(zeros, dtype=torch.float).reshape(1, -1), "advantages": torch.tensor(zeros, dtype=torch.float).reshape(1, -1), "ref_logprobs": torch.tensor(zeros, dtype=torch.float).reshape(1, -1), From a1527d6a68855abfdc05167cac71b3b29e9cea79 Mon Sep 17 00:00:00 2001 From: Alex Piche Date: Sun, 28 Sep 2025 04:10:43 +0000 Subject: [PATCH 07/35] pass group --- pipelinerl/finetune_loop.py | 7 +++++- pipelinerl/transformers_compat.py | 41 +++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) create mode 100644 pipelinerl/transformers_compat.py diff --git a/pipelinerl/finetune_loop.py b/pipelinerl/finetune_loop.py index 0948e056..cfb04345 100644 --- a/pipelinerl/finetune_loop.py +++ b/pipelinerl/finetune_loop.py @@ -660,7 +660,12 @@ def toggle_sync(sync: bool): assert batch.seq_boundaries is not None update_ring_flash_attn_params(batch.seq_boundaries, seq_parallel_group) loss, this_step_rl_metrics = rl_step( - model, batch, training_metrics.completed_steps, final_train_steps, rl_config + model, + batch, + training_metrics.completed_steps, + final_train_steps, + rl_config, + seq_parallel_group=seq_parallel_group, ) if is_sentinel_batch: # zero out the loss and do not update the metrics diff --git a/pipelinerl/transformers_compat.py b/pipelinerl/transformers_compat.py new file mode 100644 index 00000000..5e2805d1 --- /dev/null +++ b/pipelinerl/transformers_compat.py @@ -0,0 +1,41 @@ +"""Compatibility helpers for dealing with transformers regressions.""" + +from __future__ import annotations + +import logging + +logger = logging.getLogger(__name__) + +_PATCHED = False + + +def ensure_mistral3_auto_causal_lm_registered() -> None: + """Register Mistral 3 config for AutoModelForCausalLM when missing.""" + global _PATCHED + if _PATCHED: + return + + try: + from transformers.models.auto import modeling_auto + except Exception as exc: # pragma: no cover - optional dependency guard + logger.debug("transformers unavailable; skipping Mistral 3 auto registration: %s", exc) + return + + if modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get("mistral3"): + _PATCHED = True + return + + modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES["mistral3"] = "Mistral3ForConditionalGeneration" + # Some utilities rely on MODEL_WITH_LM_HEAD, keep them in sync. + modeling_auto.MODEL_WITH_LM_HEAD_MAPPING_NAMES["mistral3"] = "Mistral3ForConditionalGeneration" + + try: + from transformers import AutoModelForCausalLM + from transformers.models.mistral3.configuration_mistral3 import Mistral3Config + + # Touch the lazy mapping once so the entry is registered without materializing weights. 
+ _ = AutoModelForCausalLM._model_mapping[Mistral3Config] + except Exception as exc: # pragma: no cover - optional dependency guard + logger.debug("Unable to prime Mistral 3 causal LM mapping: %s", exc) + + _PATCHED = True From 2f8a3ac7b82be8d82e3feeb107c127f37ecf0cd2 Mon Sep 17 00:00:00 2001 From: Alex Piche Date: Sun, 28 Sep 2025 04:20:36 +0000 Subject: [PATCH 08/35] fix logging --- pipelinerl/finetune/rl/__init__.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/pipelinerl/finetune/rl/__init__.py b/pipelinerl/finetune/rl/__init__.py index c2438f9a..04edc6e0 100644 --- a/pipelinerl/finetune/rl/__init__.py +++ b/pipelinerl/finetune/rl/__init__.py @@ -301,6 +301,10 @@ def rl_step( clamp_log_ratio_new_old_indicators = clamped_group_ratio != group_ratio_new_old surr2 = clamped_group_ratio * group_advantages_t policy_loss_total = -torch.min(surr1, surr2).sum() + expanded_indicators = torch.zeros_like(masks_shifted, dtype=torch.float) + for (start, end), val in zip(segments, clamp_log_ratio_new_old_indicators.squeeze()): + expanded_indicators[start:end] = float(val) + clamp_log_ratio_new_old_indicators = expanded_indicators case _: raise ValueError(f"Unknown algorithm {config.policy_loss}") @@ -376,9 +380,9 @@ def rl_step( "clamp_log_ratio_ref_new_indicator": sum_sum( clamp_log_ratio_ref_new_indicators / num_labels_in_seq, masks_shifted, segments ).item(), - #"clamp_log_ratio_new_old_indicator": sum_sum( - # clamp_log_ratio_new_old_indicators / num_labels_in_seq, masks_shifted, segments - #).item(), + "clamp_log_ratio_new_old_indicator": sum_sum( + clamp_log_ratio_new_old_indicators / num_labels_in_seq, masks_shifted, segments + ).item(), #"num_nans": torch.isnan(loss).sum().item(), "token_weight": sum_sum(tokens_weights / num_labels_in_seq, masks_shifted, segments).item(), "max_token_weight": tokens_weights[masks_shifted].max().item(), From 8d44afdf20b53aab894e78aa3c745ef33559cd78 Mon Sep 17 00:00:00 2001 From: Alex Piche Date: Sun, 28 Sep 2025 04:37:01 +0000 Subject: [PATCH 09/35] upd --- pipelinerl/finetune/rl/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pipelinerl/finetune/rl/__init__.py b/pipelinerl/finetune/rl/__init__.py index 04edc6e0..de8a9686 100644 --- a/pipelinerl/finetune/rl/__init__.py +++ b/pipelinerl/finetune/rl/__init__.py @@ -302,8 +302,10 @@ def rl_step( surr2 = clamped_group_ratio * group_advantages_t policy_loss_total = -torch.min(surr1, surr2).sum() expanded_indicators = torch.zeros_like(masks_shifted, dtype=torch.float) - for (start, end), val in zip(segments, clamp_log_ratio_new_old_indicators.squeeze()): - expanded_indicators[start:end] = float(val) + # Expand per-sequence indicators to token-level across segment ranges + # Flatten to 1-D so single-sequence cases don't produce 0-d tensors + for (start, end), val in zip(segments, clamp_log_ratio_new_old_indicators.flatten()): + expanded_indicators[0, start:end] = float(val) clamp_log_ratio_new_old_indicators = expanded_indicators case _: raise ValueError(f"Unknown algorithm {config.policy_loss}") From f8e026e884e1558c6495dad4ea1ec01d09ca79ca Mon Sep 17 00:00:00 2001 From: Alex Piche Date: Sun, 28 Sep 2025 05:01:53 +0000 Subject: [PATCH 10/35] sentinel batch --- pipelinerl/finetune/rl/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pipelinerl/finetune/rl/utils.py b/pipelinerl/finetune/rl/utils.py index 3cc3265f..cb3a603f 100644 --- a/pipelinerl/finetune/rl/utils.py +++ b/pipelinerl/finetune/rl/utils.py @@ -129,6 +129,8 @@ def 
per_segment_sums( # Slice and unify device/dtypes seg = segment_ids[:, 1:].contiguous().squeeze(0).to(dtype=torch.long) + if seg.numel() == 0: + return (torch.ones(0), torch.ones(0), torch.ones(0)) local_max = seg.max().to(torch.int64) if seq_parallel_group is None or not dist.is_available() or not dist.is_initialized(): From bf2fce9ae6b7effc110df02a6bd6950cd9988a39 Mon Sep 17 00:00:00 2001 From: Alex Piche Date: Sun, 28 Sep 2025 05:06:32 +0000 Subject: [PATCH 11/35] uod --- pipelinerl/finetune/rl/utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pipelinerl/finetune/rl/utils.py b/pipelinerl/finetune/rl/utils.py index cb3a603f..01b7ef97 100644 --- a/pipelinerl/finetune/rl/utils.py +++ b/pipelinerl/finetune/rl/utils.py @@ -130,7 +130,14 @@ def per_segment_sums( # Slice and unify device/dtypes seg = segment_ids[:, 1:].contiguous().squeeze(0).to(dtype=torch.long) if seg.numel() == 0: - return (torch.ones(0), torch.ones(0), torch.ones(0)) + # keep requires_grad consistent + device = log_ratio_new_old.device + dtype = log_ratio_new_old.dtype + requires_grad = log_ratio_new_old.requires_grad or advantages.requires_grad + + make_zeros = lambda: torch.zeros(0, dtype=dtype, device=device, requires_grad=requires_grad) + return make_zeros(), make_zeros(), make_zeros() + local_max = seg.max().to(torch.int64) if seq_parallel_group is None or not dist.is_available() or not dist.is_initialized(): From 39c997cdeec82b77db956a67ed67954d22c10d3c Mon Sep 17 00:00:00 2001 From: Alex Piche Date: Sun, 28 Sep 2025 05:12:53 +0000 Subject: [PATCH 12/35] typo --- pipelinerl/finetune/rl/utils.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/pipelinerl/finetune/rl/utils.py b/pipelinerl/finetune/rl/utils.py index 01b7ef97..9296056c 100644 --- a/pipelinerl/finetune/rl/utils.py +++ b/pipelinerl/finetune/rl/utils.py @@ -131,12 +131,7 @@ def per_segment_sums( seg = segment_ids[:, 1:].contiguous().squeeze(0).to(dtype=torch.long) if seg.numel() == 0: # keep requires_grad consistent - device = log_ratio_new_old.device - dtype = log_ratio_new_old.dtype - requires_grad = log_ratio_new_old.requires_grad or advantages.requires_grad - - make_zeros = lambda: torch.zeros(0, dtype=dtype, device=device, requires_grad=requires_grad) - return make_zeros(), make_zeros(), make_zeros() + return log_ratio_new_old, advantages, torch.ones(0, device=log_ratio_new_old.device, dtype=log_ratio_new_old.dtype) local_max = seg.max().to(torch.int64) From 760c3c54bea0400b10bc4499971652c0eaa63283 Mon Sep 17 00:00:00 2001 From: Alex Piche Date: Sun, 28 Sep 2025 05:16:54 +0000 Subject: [PATCH 13/35] typo --- pipelinerl/finetune/rl/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipelinerl/finetune/rl/utils.py b/pipelinerl/finetune/rl/utils.py index 9296056c..a7580f0b 100644 --- a/pipelinerl/finetune/rl/utils.py +++ b/pipelinerl/finetune/rl/utils.py @@ -131,7 +131,7 @@ def per_segment_sums( seg = segment_ids[:, 1:].contiguous().squeeze(0).to(dtype=torch.long) if seg.numel() == 0: # keep requires_grad consistent - return log_ratio_new_old, advantages, torch.ones(0, device=log_ratio_new_old.device, dtype=log_ratio_new_old.dtype) + return log_ratio_new_old, advantages, torch.ones_like(log_ratio_new_old, device=log_ratio_new_old.device, dtype=log_ratio_new_old.dtype) local_max = seg.max().to(torch.int64) From 8f4a09cfbd6f206dc203d4951dac7661de2bbe79 Mon Sep 17 00:00:00 2001 From: Alex Piche Date: Sun, 28 Sep 2025 12:26:04 +0000 Subject: [PATCH 14/35] fix rank --- 
pipelinerl/finetune/rl/utils.py | 43 ++++++++++++++++++++++----------- 1 file changed, 29 insertions(+), 14 deletions(-) diff --git a/pipelinerl/finetune/rl/utils.py b/pipelinerl/finetune/rl/utils.py index a7580f0b..aae3a0f5 100644 --- a/pipelinerl/finetune/rl/utils.py +++ b/pipelinerl/finetune/rl/utils.py @@ -128,19 +128,26 @@ def per_segment_sums( raise ValueError(f"Expected segment_ids of shape [1, L], got {tuple(segment_ids.shape)}") # Slice and unify device/dtypes - seg = segment_ids[:, 1:].contiguous().squeeze(0).to(dtype=torch.long) - if seg.numel() == 0: - # keep requires_grad consistent - return log_ratio_new_old, advantages, torch.ones_like(log_ratio_new_old, device=log_ratio_new_old.device, dtype=log_ratio_new_old.dtype) - - local_max = seg.max().to(torch.int64) - + # Always compute a consistent number of collectives across ranks to avoid NCCL deadlocks. + # We cannot call seg.max() on empty tensors, so handle that path explicitly while still + # participating in all necessary collectives. + seg = segment_ids[:, 1:].contiguous().squeeze(0).to(dtype=torch.long, device=log_ratio_new_old.device) + seg_is_empty = seg.numel() == 0 + + # Determine n_segments. For distributed, we first all-reduce the local max (or -1 for empty) + # so all ranks agree on a global n_segments. This preserves the collective call even for empty ranks. if seq_parallel_group is None or not dist.is_available() or not dist.is_initialized(): - n_segments = int(local_max.item()) + 1 + if seg_is_empty: + n_segments = 0 + else: + n_segments = int(seg.max().to(torch.int64).item()) + 1 else: + local_max = torch.tensor(-1, dtype=torch.int64, device=log_ratio_new_old.device) + if not seg_is_empty: + local_max = seg.max().to(torch.int64) global_max = local_max.clone() dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=seq_parallel_group) - n_segments = int(global_max.item()) + 1 + n_segments = int(global_max.item()) + 1 # will be 0 if all ranks are empty mask = masks_shifted[:, :seg.numel()].contiguous().squeeze(0) lrn = log_ratio_new_old[:, :seg.numel()].contiguous().squeeze(0) @@ -159,7 +166,7 @@ def per_segment_sums( if valid.ndim != 1 or valid.shape[0] != seg.shape[0]: raise ValueError("Mask shape mismatch after alignment with segment_ids.") - if valid.any(): + if (not seg_is_empty) and valid.any(): seg_v = seg[valid] # indices actually used w_v = mask[valid] # weights for counts lrn_v = lrn[valid] @@ -192,11 +199,19 @@ def per_segment_sums( lrn_sum = torch.zeros(n_segments, dtype=lrn.dtype, device=device) adv_sum = torch.zeros(n_segments, dtype=lrn.dtype, device=device) - # Optional all-reduce across sequence-parallel group + # Optional all-reduce across sequence-parallel group. Ensure all ranks perform the same + # number of collectives, even when n_segments == 0 locally (e.g., sentinel micro-batches). 
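+ # (With NCCL, a rank that skips a collective while its peers enter one
+ # stalls the whole group, so even empty ranks must issue matching calls.)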
if seq_parallel_group is not None and dist.is_available() and dist.is_initialized(): from torch.distributed.nn.functional import all_reduce - token_count = all_reduce(token_count, op=dist.ReduceOp.SUM, group=seq_parallel_group) - lrn_sum = all_reduce(lrn_sum, op=dist.ReduceOp.SUM, group=seq_parallel_group) - adv_sum = all_reduce(adv_sum, op=dist.ReduceOp.SUM, group=seq_parallel_group) + if n_segments == 0: + # Perform three dummy all-reduces to match the non-empty path (token_count, lrn_sum, adv_sum) + dummy = torch.zeros(1, dtype=lrn.dtype, device=device) + _ = all_reduce(dummy, op=dist.ReduceOp.SUM, group=seq_parallel_group) + _ = all_reduce(dummy, op=dist.ReduceOp.SUM, group=seq_parallel_group) + _ = all_reduce(dummy, op=dist.ReduceOp.SUM, group=seq_parallel_group) + else: + token_count = all_reduce(token_count, op=dist.ReduceOp.SUM, group=seq_parallel_group) + lrn_sum = all_reduce(lrn_sum, op=dist.ReduceOp.SUM, group=seq_parallel_group) + adv_sum = all_reduce(adv_sum, op=dist.ReduceOp.SUM, group=seq_parallel_group) return lrn_sum, adv_sum, token_count From e012dfeec8d0e804deeba3eefb726194418ee833 Mon Sep 17 00:00:00 2001 From: Alex Piche Date: Sun, 28 Sep 2025 13:02:22 +0000 Subject: [PATCH 15/35] loss with no grad --- pipelinerl/finetune/rl/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipelinerl/finetune/rl/__init__.py b/pipelinerl/finetune/rl/__init__.py index de8a9686..a2a3a835 100644 --- a/pipelinerl/finetune/rl/__init__.py +++ b/pipelinerl/finetune/rl/__init__.py @@ -300,7 +300,7 @@ def rl_step( clamped_group_ratio = torch.clamp(group_ratio_new_old, 1 - config.epsilon, 1 + config.epsilon) clamp_log_ratio_new_old_indicators = clamped_group_ratio != group_ratio_new_old surr2 = clamped_group_ratio * group_advantages_t - policy_loss_total = -torch.min(surr1, surr2).sum() + policy_loss_total = -torch.min(surr1, surr2).sum() if not batch.sentinel else group_ratio_new_old.sum() * 0.0 expanded_indicators = torch.zeros_like(masks_shifted, dtype=torch.float) # Expand per-sequence indicators to token-level across segment ranges # Flatten to 1-D so single-sequence cases don't produce 0-d tensors From 4779d1ea25d7a9aa8880d505f2f04ee6eb96f839 Mon Sep 17 00:00:00 2001 From: Alex Piche Date: Sun, 28 Sep 2025 14:50:56 +0000 Subject: [PATCH 16/35] fix grad --- pipelinerl/finetune/rl/__init__.py | 7 +++++-- pipelinerl/finetune/rl/utils.py | 25 +++++++++++++++++++------ 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/pipelinerl/finetune/rl/__init__.py b/pipelinerl/finetune/rl/__init__.py index a2a3a835..ab9fbdac 100644 --- a/pipelinerl/finetune/rl/__init__.py +++ b/pipelinerl/finetune/rl/__init__.py @@ -286,7 +286,6 @@ def rl_step( if segments is None: raise ValueError("GSPO loss requires packed sequences with segments") # Aggregate per-sequence means over valid (labeled) tokens only. - # Skip sequences with zero labeled tokens to avoid NaNs from mean([]). 
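+ # (No skip needed anymore: the token-count denominator is clamped
+ # downstream, so empty segments cannot produce NaN means.)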
lrn_sum, adv_sum, tok_count = per_segment_sums( batch.segment_ids, masks_shifted, @@ -300,7 +299,11 @@ def rl_step( clamped_group_ratio = torch.clamp(group_ratio_new_old, 1 - config.epsilon, 1 + config.epsilon) clamp_log_ratio_new_old_indicators = clamped_group_ratio != group_ratio_new_old surr2 = clamped_group_ratio * group_advantages_t - policy_loss_total = -torch.min(surr1, surr2).sum() if not batch.sentinel else group_ratio_new_old.sum() * 0.0 + # If we have a sentinel or no segments, return a zero loss but keep graph + if batch.sentinel or surr1.numel() == 0: + policy_loss_total = new_logprobs[..., :1].sum() * 0.0 + else: + policy_loss_total = -torch.min(surr1, surr2).sum() expanded_indicators = torch.zeros_like(masks_shifted, dtype=torch.float) # Expand per-sequence indicators to token-level across segment ranges # Flatten to 1-D so single-sequence cases don't produce 0-d tensors diff --git a/pipelinerl/finetune/rl/utils.py b/pipelinerl/finetune/rl/utils.py index aae3a0f5..8d3b0593 100644 --- a/pipelinerl/finetune/rl/utils.py +++ b/pipelinerl/finetune/rl/utils.py @@ -31,11 +31,19 @@ def mask_sum(values: torch.Tensor, mask: torch.Tensor, axis: Optional[int] = Non def mask_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[int] = None) -> torch.Tensor: - """Compute mean of tensor with a masked values.""" + """Compute mean of tensor with masked values, safely handling empty masks. + + Uses a clamped denominator so the result is 0 when the mask sums to 0, + while keeping the computation connected to the graph via the numerator. + """ if axis is not None: - return (values * mask).nan_to_num(0).sum(axis=axis) / mask.sum(axis=axis) # type: ignore + num = (values * mask).nan_to_num(0).sum(axis=axis) # type: ignore + den = mask.sum(axis=axis).clamp(min=1).to(dtype=values.dtype) # type: ignore + return num / den else: - return (values * mask).nan_to_num(0).sum() / mask.sum() + num = (values * mask).nan_to_num(0).sum() + den = mask.sum().clamp(min=1).to(dtype=values.dtype) + return num / den def mean_sum(values: torch.Tensor, masks: torch.Tensor, segments: list | None): @@ -194,10 +202,15 @@ def per_segment_sums( lrn_sum.index_add_(0, seg_v, lrn_v * w_v) adv_sum.index_add_(0, seg_v, adv_v * w_v) else: - # No valid tokens: return zeros + # No valid tokens: return zeros, but keep graph connectivity so downstream + # losses still "require_grad" (DeepSpeed/PyTorch expect this). + # Create zero scalars tied to inputs; gradients will be zero but the graph remains. + zero_from_lrn = (lrn * 0).sum() # requires_grad if lrn requires_grad + zero_from_adv = (adv * 0).sum() # requires_grad if adv requires_grad token_count = torch.zeros(n_segments, dtype=lrn.dtype, device=device) - lrn_sum = torch.zeros(n_segments, dtype=lrn.dtype, device=device) - adv_sum = torch.zeros(n_segments, dtype=lrn.dtype, device=device) + # Broadcastable add keeps autograd connection even when n_segments > 0 + lrn_sum = torch.zeros(n_segments, dtype=lrn.dtype, device=device) + zero_from_lrn + adv_sum = torch.zeros(n_segments, dtype=lrn.dtype, device=device) + zero_from_adv # Optional all-reduce across sequence-parallel group. Ensure all ranks perform the same # number of collectives, even when n_segments == 0 locally (e.g., sentinel micro-batches). 
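A minimal, self-contained sketch of the per-sequence (GSPO) reduction that
patches 03-16 converge on; the names and shapes below are illustrative, not
the exact pipeline API:

    import torch

    def gspo_sequence_ratios(seg_ids, mask, log_ratio, n_segments):
        # Accumulate masked per-token log-ratios and token counts per segment.
        w = mask.to(log_ratio.dtype)
        sums = torch.zeros(n_segments, dtype=log_ratio.dtype)
        counts = torch.zeros(n_segments, dtype=log_ratio.dtype)
        sums.index_add_(0, seg_ids, log_ratio * w)
        counts.index_add_(0, seg_ids, w)
        # Sequence-level importance ratio: exp of the mean token log-ratio.
        return torch.exp(sums / counts.clamp(min=1e-6))

    # Two packed sequences (ids 0 and 1), five tokens, one token masked out.
    seg = torch.tensor([0, 0, 0, 1, 1])
    mask = torch.tensor([1.0, 1.0, 0.0, 1.0, 1.0])
    lr = torch.tensor([0.10, -0.20, 0.30, 0.05, 0.05])
    print(gspo_sequence_ratios(seg, mask, lr, 2))  # ~[0.9512, 1.0513]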
From ec1dbb7bb5c08de14604091a98c89a000e455063 Mon Sep 17 00:00:00 2001 From: Alex Piche Date: Mon, 29 Sep 2025 23:56:00 +0000 Subject: [PATCH 17/35] longer timeout --- pipelinerl/actor.py | 2 +- pipelinerl/finetune/rl/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index 475e1ee7..47568b4c 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -186,7 +186,7 @@ async def rollout_and_maybe_produce_result( last_logged = time.time() logger.info("Starting rollout scheduler") connector = aiohttp.TCPConnector(limit=50000, limit_per_host=50000, keepalive_timeout=1.0) - timeout = aiohttp.ClientTimeout(total=3600.0, connect=3600.0, sock_read=3600.0) + timeout = aiohttp.ClientTimeout(total=10*3600.0, connect=10*3600.0, sock_read=10*3600.0) async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session: while True: if time.time() - last_logged > 10.0 and sum(active_rollouts): diff --git a/pipelinerl/finetune/rl/__init__.py b/pipelinerl/finetune/rl/__init__.py index ab9fbdac..733e33ef 100644 --- a/pipelinerl/finetune/rl/__init__.py +++ b/pipelinerl/finetune/rl/__init__.py @@ -294,7 +294,7 @@ def rl_step( seq_parallel_group=seq_parallel_group, ) group_ratio_new_old = torch.exp(lrn_sum / tok_count.clamp(min=1e-6)).unsqueeze(1).unsqueeze(2) - group_advantages_t = (adv_sum / tok_count.clamp(min=1e-6)).unsqueeze(1).unsqueeze(2) + group_advantages_t = (adv_sum / tok_count.clamp(min=1e-6)).unsqueeze(1).unsqueeze(2).detach() surr1 = group_ratio_new_old * group_advantages_t clamped_group_ratio = torch.clamp(group_ratio_new_old, 1 - config.epsilon, 1 + config.epsilon) clamp_log_ratio_new_old_indicators = clamped_group_ratio != group_ratio_new_old From bdf7610dfefd7bfcaade9faa2564fe4de2e5e62b Mon Sep 17 00:00:00 2001 From: Alex Piche Date: Tue, 30 Sep 2025 23:32:11 +0000 Subject: [PATCH 18/35] vllm quant --- conf/base.yaml | 7 +- pipelinerl/vllm0.py | 11 +- pipelinerl/vllm1.py | 8 + pipelinerl/vllm_quantization.py | 287 ++++++++++++++++++++++++++++++++ 4 files changed, 310 insertions(+), 3 deletions(-) create mode 100644 pipelinerl/vllm_quantization.py diff --git a/conf/base.yaml b/conf/base.yaml index ac44fdde..82148e75 100644 --- a/conf/base.yaml +++ b/conf/base.yaml @@ -47,7 +47,7 @@ llm: temperature: 1.0 test_llm: parameters: - max_tokens: 16000 + max_tokens: 8192 temperature: 1.0 top_p: 0.95 top_k: 50 @@ -67,6 +67,8 @@ vllm_config: tensor-parallel-size: 1 pipeline-parallel-size: 1 generation-config: vllm + max_model_len: 12000 + quantization: bf16_last_layer_fp32 world: replicas: 1 @@ -75,7 +77,8 @@ world: preprocessor_fraction: 0 finetune_fraction: 4 - env_replicas: 2 + # Number of environment servers per actor VLLM server + env_replicas_per_actor: 1 actor_group_port: 9000 environment_start_port: 7777 diff --git a/pipelinerl/vllm0.py b/pipelinerl/vllm0.py index 92c51085..950310be 100644 --- a/pipelinerl/vllm0.py +++ b/pipelinerl/vllm0.py @@ -36,6 +36,7 @@ import torch.distributed as dist from pipelinerl.finetune_loop import TrainerMessage, WeightUpdateRequest import pipelinerl.torch_utils +import pipelinerl.vllm_quantization # noqa: F401 - registers custom quantization configs logger = logging.getLogger(__name__) # configure this logger individually, in order to avoid messign @@ -47,6 +48,14 @@ handler.setFormatter(formatter) logger.addHandler(handler) +# Ensure quantization module logs are visible, too. 
+_qlogger = logging.getLogger("pipelinerl.vllm_quantization") +_qlogger.setLevel(logging.INFO) +# Avoid duplicate handlers if this module reloads. +if not _qlogger.handlers: + _qlogger.addHandler(handler) +_qlogger.propagate = False + def make_worker_class(multi_step: bool): base_class = MultiStepWorker if multi_step else Worker @@ -275,4 +284,4 @@ def run_llm(): args = parser.parse_args() validate_parsed_serve_args(args) - uvloop.run(run_server(args)) \ No newline at end of file + uvloop.run(run_server(args)) diff --git a/pipelinerl/vllm1.py b/pipelinerl/vllm1.py index 80cba297..a5b69e82 100644 --- a/pipelinerl/vllm1.py +++ b/pipelinerl/vllm1.py @@ -27,6 +27,7 @@ from pipelinerl.finetune_loop import WeightUpdateRequest from typing import Any, Protocol, runtime_checkable import pipelinerl.torch_utils +import pipelinerl.vllm_quantization # noqa: F401 - registers custom quantization configs logger = logging.getLogger(__name__) # configure this logger individually, in order to avoid messign @@ -38,6 +39,13 @@ handler.setFormatter(formatter) logger.addHandler(handler) +# Ensure quantization module logs are visible, too. +_qlogger = logging.getLogger("pipelinerl.vllm_quantization") +_qlogger.setLevel(logging.INFO) +if not _qlogger.handlers: + _qlogger.addHandler(handler) +_qlogger.propagate = False + @runtime_checkable class LikeWorker(Protocol): diff --git a/pipelinerl/vllm_quantization.py b/pipelinerl/vllm_quantization.py new file mode 100644 index 00000000..af63333e --- /dev/null +++ b/pipelinerl/vllm_quantization.py @@ -0,0 +1,287 @@ +import os +import torch +import logging +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization import get_quantization_config +from vllm.model_executor.layers.quantization import register_quantization_config +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + UnquantizedEmbeddingMethod, VocabParallelEmbedding) + +logger = logging.getLogger(__name__) + +@register_quantization_config("bf16_last_layer_fp32") +class BF16WithLastLayerFP32(QuantizationConfig): + """ + A custom mixed-precision configuration for vLLM. + + This configuration keeps the last layer in float32 for maximum precision + while running all other layers using bfloat16 for improved performance + and reduced memory usage. + """ + + def __init__(self, config: object | None = None): + super().__init__() + self.default_dtype = self._resolve_default_dtype(config) + logger.info( + "Initialized quantization config '%s' with default_dtype=%s; last layer forced to %s", + self.get_name(), + self.default_dtype, + torch.float32, + ) + + @staticmethod + def _resolve_default_dtype(config: object | None) -> torch.dtype: + """Best-effort extraction of the default dtype for non-final layers.""" + + if config is None: + return torch.bfloat16 + + # Passed a torch dtype directly. + if isinstance(config, torch.dtype): + return config + + # Allow string representations. + if isinstance(config, str): + return _string_to_dtype(config) + + # HuggingFace model config or similar objects typically expose `dtype`. + dtype = getattr(config, "dtype", None) + if dtype is not None: + if isinstance(dtype, torch.dtype): + return dtype + return _string_to_dtype(dtype) + + # Dictionary based quantization configs can also specify the dtype. 
+ if isinstance(config, dict): + for key in ("default_dtype", "activation_dtype", "dtype"): + value = config.get(key) + if value is None: + continue + if isinstance(value, torch.dtype): + return value + return _string_to_dtype(value) + + return torch.bfloat16 + + def get_supported_activations(self) -> list[str]: + # vLLM does not currently query this helper, but keep it for completeness. + return ["relu", "gelu", "silu"] + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> LinearMethodBase | None: + """Return quantization method for the given layer.""" + + is_last_layer = prefix.endswith("lm_head") or prefix.endswith( + ".lm_head") or (".lm_head." in prefix) or ("lm_head." in prefix) + + # Heuristic: many models (e.g., LLaMA/Qwen) tie the output projection + # (unembedding) to the input token embedding. When weight tying is used, + # there is no distinct Linear lm_head module. In that case, force the + # embedding weights to float32 so the final logits use FP32 params. + tied_unembed = False + if isinstance(layer, VocabParallelEmbedding): + name = prefix.lower() + if any(s in name for s in ["embed_tokens", "tok_embeddings", "word_embeddings", "wte"]): + tied_unembed = True + + if isinstance(layer, VocabParallelEmbedding): + if is_last_layer: + # Final logits matmul should run in FP32. Use a method that + # upcasts activations and uses an FP32 weight tensor (reusing + # an existing FP32 param if present to avoid copies). + logger.info( + "Quant config forcing FP32 logits matmul for %s (%s)", + prefix, + layer.__class__.__name__, + ) + return _FP32UnembedEmbeddingMethod() + if tied_unembed: + # Keep BF16 parameters for input embedding, but force logits + # matmul to FP32 via a custom apply() used by LogitsProcessor. + logger.info( + "Detected tied embedding %s (%s); using FP32 unembedding matmul while keeping BF16 params.", + prefix, + layer.__class__.__name__, + ) + return _FP32UnembedEmbeddingMethod() + logger.info("Quant config leaving embedding %s (%s) unmodified", prefix, layer.__class__.__name__) + return None + + if isinstance(layer, LinearBase): + if is_last_layer: + logger.info("Quant config forcing FP32 linear layer for %s (%s)", prefix, layer.__class__.__name__) + # Also pin activations to FP32 via a pre-hook for clarity. 
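# (register_forward_pre_hook runs before forward(); a hook returning a
# modified args tuple replaces the module's positional inputs, which is
# how the cast below takes effect.)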
+ try: + def _cast_input_to_fp32(_mod, _inp): + if not _inp: + return _inp + args = list(_inp) + if isinstance(args[0], torch.Tensor) and args[0].dtype != torch.float32: + logger.info("Casting input activation to FP32 for %s (%s)", prefix, layer.__class__.__name__) + args[0] = args[0].to(torch.float32) + return tuple(args) + + # Avoid duplicate hooks on reload + if not hasattr(layer, "_pipelinerl_fp32_hook_installed"): + layer.register_forward_pre_hook(_cast_input_to_fp32) + layer._pipelinerl_fp32_hook_installed = True + except Exception as _hook_err: # pragma: no cover - defensive logging + logger.warning("Failed to install FP32 input hook for %s: %s", prefix, _hook_err) + return _ForcedDTypeLinearMethod(torch.float32) + logger.info("Quant config setting dtype %s for %s (%s)", self.default_dtype, prefix, layer.__class__.__name__) + return _ForcedDTypeLinearMethod(self.default_dtype) + + logger.info("Quant config leaving module %s (%s) unmodified", prefix, layer.__class__.__name__) + return None + + def get_supported_act_dtypes(self) -> list[torch.dtype]: + dtypes = [self.default_dtype] + if torch.float32 not in dtypes: + dtypes.append(torch.float32) + return dtypes + + @classmethod + def get_min_capability(cls) -> int: + return 0 + + @classmethod + def get_name(cls) -> str: + return "bf16_last_layer_fp32" + + @classmethod + def get_config_filenames(cls) -> list[str]: + return [] + + @classmethod + def from_config(cls, config: dict | None = None): + return cls(config) + + def get_tp_size(self) -> int: + return 1 + + def get_tp_group(self, tp_size: int): + return None + + def get_supported_dtypes(self) -> list[torch.dtype]: + dtypes = [self.default_dtype] + if torch.float32 not in dtypes: + dtypes.append(torch.float32) + return dtypes + + + +class _ForcedDTypeLinearMethod(UnquantizedLinearMethod): + """Linear method that enforces a specific parameter dtype.""" + + def __init__(self, target_dtype: torch.dtype): + super().__init__() + self._target_dtype = target_dtype + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + logger.info("Creating linear weights for %s with dtype %s", layer.__class__.__name__, self._target_dtype) + return super().create_weights(layer, + input_size_per_partition, + output_partition_sizes, + input_size, + output_size, + params_dtype=self._target_dtype, + **extra_weight_attrs) + + +class _ForcedDTypeEmbeddingMethod(UnquantizedEmbeddingMethod): + """Embedding method that enforces a specific parameter dtype.""" + + def __init__(self, target_dtype: torch.dtype): + super().__init__() + self._target_dtype = target_dtype + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + logger.info("Creating embedding weights for %s with dtype %s", layer.__class__.__name__, self._target_dtype) + return super().create_weights(layer, + input_size_per_partition, + output_partition_sizes, + input_size, + output_size, + params_dtype=self._target_dtype, + **extra_weight_attrs) + + +class _FP32UnembedEmbeddingMethod(UnquantizedEmbeddingMethod): + """Use BF16 weights in module state, but compute logits matmul in FP32. + + This method does NOT change the dtype of the embedding weights stored in + the module (so input token embeddings remain in BF16). 
It only casts the + hidden states and the weight to FP32 inside apply(), which is used by the + LogitsProcessor to compute final logits via F.linear. + """ + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None) -> torch.Tensor: # type: ignore[override] + # Log once to avoid spamming per token. + if not hasattr(layer, "_pipelinerl_fp32_unembed_logged"): + logger.info("Computing unembedding (logits) in FP32 for %s", layer.__class__.__name__) + layer._pipelinerl_fp32_unembed_logged = True + + # Upcast activations to FP32 for the final matmul. + x32 = x if x.dtype == torch.float32 else x.to(torch.float32) + target_device = x32.device + + # Use the existing FP32 parameter directly if already suitable to + # avoid extra copies. Otherwise, keep a cached FP32 copy per device. + w_param = layer.weight + if w_param.dtype == torch.float32 and w_param.device == target_device: + w32 = w_param + else: + w_attr = "_pipelinerl_fp32_unembed_weight" + w32 = getattr(layer, w_attr, None) + needs_new_copy = ( + w32 is None or + w32.dtype != torch.float32 or + w32.device != target_device or + w32.shape != w_param.shape + ) + if needs_new_copy: + # Only copy when necessary. + w32 = w_param.to(device=target_device, dtype=torch.float32) + setattr(layer, w_attr, w32) + + b32 = None + if bias is not None: + b32 = bias if bias.dtype == torch.float32 else bias.to(torch.float32) + if b32.device != target_device: + b32 = b32.to(target_device) + + return torch.nn.functional.linear(x32, w32, b32) + + +def _string_to_dtype(value: str) -> torch.dtype: + normalized = value.lower().replace("torch.", "") + mapping = { + "bfloat16": torch.bfloat16, + "bf16": torch.bfloat16, + "float32": torch.float32, + "fp32": torch.float32, + "float16": torch.float16, + "fp16": torch.float16, + "half": torch.float16, + } + try: + return mapping[normalized] + except KeyError as exc: # pragma: no cover - defensive + raise ValueError(f"Unsupported dtype string: {value}") from exc + + +if __name__ == "__main__": + get_quantization_config("bf16_last_layer_fp32") From 5836c69f7c9b9892cf9b510566dd7a9831bf575e Mon Sep 17 00:00:00 2001 From: Alex Piche Date: Tue, 30 Sep 2025 23:46:37 +0000 Subject: [PATCH 19/35] no quant baseline --- conf/base.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conf/base.yaml b/conf/base.yaml index 82148e75..1d91cdb9 100644 --- a/conf/base.yaml +++ b/conf/base.yaml @@ -68,7 +68,7 @@ vllm_config: pipeline-parallel-size: 1 generation-config: vllm max_model_len: 12000 - quantization: bf16_last_layer_fp32 + #quantization: bf16_last_layer_fp32 world: replicas: 1 From e8aaa041821e733459d581df872083d8a4f83f10 Mon Sep 17 00:00:00 2001 From: Alex Piche Date: Thu, 2 Oct 2025 02:25:33 +0000 Subject: [PATCH 20/35] log less often --- pipelinerl/actor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index 47568b4c..a675de55 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -189,7 +189,7 @@ async def rollout_and_maybe_produce_result( timeout = aiohttp.ClientTimeout(total=10*3600.0, connect=10*3600.0, sock_read=10*3600.0) async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session: while True: - if time.time() - last_logged > 10.0 and sum(active_rollouts): + if time.time() - last_logged > 30.0 and sum(active_rollouts): logger.info( f"{scheduler_name}: " f"rollouts in progress: {sum(active_rollouts)}, " From 1aa95fbad270ee1f5f496917d55433e5621bb6d8 Mon Sep 17 00:00:00 2001 
From: Alex Piche
Date: Thu, 2 Oct 2025 02:28:58 +0000
Subject: [PATCH 21/35] typo

---
 pipelinerl/finetune/rl/__init__.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/pipelinerl/finetune/rl/__init__.py b/pipelinerl/finetune/rl/__init__.py
index 733e33ef..43d0ae48 100644
--- a/pipelinerl/finetune/rl/__init__.py
+++ b/pipelinerl/finetune/rl/__init__.py
@@ -46,7 +46,8 @@ class RLConfig(BaseModel):
         default=True,
         description="Use advantages instead of rewards to compute the loss",
     )
-    epsilon: float = Field(default=0.2, description="Clip parameter for the ration of log probs")
+    epsilon_low: float = Field(default=0.2, description="Lower clip parameter for the ratio of log probs")
+    epsilon_high: float = Field(default=0.2, description="Upper clip parameter for the ratio of log probs")
     batch_size: int = Field(default=0, description="Batch size is required for normalization")
     reward_minus_kl_coef: float = Field(
         default=0.0,
@@ -272,15 +273,15 @@ def rl_step(
     match config.policy_loss:
         case "ppo":
             surr1 = ratio_new_old * log_p_weights
-            clamped_ratio = torch.clamp(ratio_new_old, 1 - config.epsilon, 1 + config.epsilon)
+            clamped_ratio = torch.clamp(ratio_new_old, 1 - config.epsilon_low, 1 + config.epsilon_high)
             clamp_log_ratio_new_old_indicators = clamped_ratio != ratio_new_old
             surr2 = clamped_ratio * log_p_weights
             policy_loss = torch.min(surr1, surr2)
         case "reinforce":
             surr1 = torch.zeros_like(ratio_new_old)
             surr2 = torch.zeros_like(ratio_new_old)
-            clamp_log_ratio_new_old_indicators = ratio_new_old > 1 + config.epsilon
-            ratio_new_old = torch.clamp(ratio_new_old, 0, 1 + config.epsilon)
+            clamp_log_ratio_new_old_indicators = ratio_new_old > 1 + config.epsilon_high
+            ratio_new_old = torch.clamp(ratio_new_old, 0, 1 + config.epsilon_high)
             policy_loss = new_logprobs * log_p_weights * ratio_new_old.detach()
         case "gspo":
             if segments is None:
@@ -296,7 +297,7 @@ def rl_step(
             group_ratio_new_old = torch.exp(lrn_sum / tok_count.clamp(min=1e-6)).unsqueeze(1).unsqueeze(2)
             group_advantages_t = (adv_sum / tok_count.clamp(min=1e-6)).unsqueeze(1).unsqueeze(2).detach()
             surr1 = group_ratio_new_old * group_advantages_t
-            clamped_group_ratio = torch.clamp(group_ratio_new_old, 1 - config.epsilon, 1 + config.epsilon)
+            clamped_group_ratio = torch.clamp(group_ratio_new_old, 1 - config.epsilon_low, 1 + config.epsilon_high)
             clamp_log_ratio_new_old_indicators = clamped_group_ratio != group_ratio_new_old
             surr2 = clamped_group_ratio * group_advantages_t
             # If we have a sentinel or no segments, return a zero loss but keep graph

From 8dfa5752059b2ff62b738e15ada45057bb550974 Mon Sep 17 00:00:00 2001
From: Alex Piche
Date: Thu, 2 Oct 2025 02:40:28 +0000
Subject: [PATCH 22/35] log every 30 seconds

---
 pipelinerl/actor.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py
index a675de55..f39bb71e 100644
--- a/pipelinerl/actor.py
+++ b/pipelinerl/actor.py
@@ -189,7 +189,8 @@ async def rollout_and_maybe_produce_result(
     timeout = aiohttp.ClientTimeout(total=10*3600.0, connect=10*3600.0, sock_read=10*3600.0)
     async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
         while True:
-            if time.time() - last_logged > 30.0 and sum(active_rollouts):
+            #if time.time() - last_logged > 30.0 and sum(active_rollouts):
+            if int(time.time())%30 == 0 and sum(active_rollouts):
                 logger.info(
                     f"{scheduler_name}: "
                     f"rollouts in progress: {sum(active_rollouts)}, "

From 140c3f21ccd2fa8ae77c70c5ac7f7a44cca63283 Mon Sep 17 00:00:00 2001
From: Alex Piche
Date: Thu, 2 Oct 2025 02:43:29 +0000 Subject: [PATCH 23/35] upd --- pipelinerl/actor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index f39bb71e..cad8f3b9 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -190,7 +190,7 @@ async def rollout_and_maybe_produce_result( async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session: while True: #if time.time() - last_logged > 30.0 and sum(active_rollouts): - if int(time.time())%30 == 0 and sum(active_rollouts): + if int(time.time()) % 10 == 0 and sum(active_rollouts): logger.info( f"{scheduler_name}: " f"rollouts in progress: {sum(active_rollouts)}, " From 935356ae175a8b01d0190427359efbe23ab7cf96 Mon Sep 17 00:00:00 2001 From: Alex Piche Date: Thu, 2 Oct 2025 02:56:24 +0000 Subject: [PATCH 24/35] more logging --- pipelinerl/actor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index cad8f3b9..9ffabe7e 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -189,8 +189,7 @@ async def rollout_and_maybe_produce_result( timeout = aiohttp.ClientTimeout(total=10*3600.0, connect=10*3600.0, sock_read=10*3600.0) async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session: while True: - #if time.time() - last_logged > 30.0 and sum(active_rollouts): - if int(time.time()) % 10 == 0 and sum(active_rollouts): + if time.time() - last_logged > 1.0 and sum(active_rollouts): logger.info( f"{scheduler_name}: " f"rollouts in progress: {sum(active_rollouts)}, " From c1347d9cb534e957a98a05d492d619d84e844e35 Mon Sep 17 00:00:00 2001 From: rafapi Date: Thu, 2 Oct 2025 19:39:36 +0000 Subject: [PATCH 25/35] Overflow --- pipelinerl/domains/math/rollouts.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pipelinerl/domains/math/rollouts.py b/pipelinerl/domains/math/rollouts.py index 8b88d877..3afd021c 100644 --- a/pipelinerl/domains/math/rollouts.py +++ b/pipelinerl/domains/math/rollouts.py @@ -14,6 +14,7 @@ class Metrics(BaseMetrics): penalty: float + overflow: bool = False class RewardTable(BaseModel): wrong_answer_not_finished: float @@ -104,6 +105,11 @@ async def generate_math_rollout( no_error=answer_status != "unparsable", no_answer=answer_status == "no_answer", penalty=overlong_penalty, + overflow=bool( + trace.input_ids + and getattr(llm.tokenizer, "eos_token_id", None) is not None + and trace.input_ids[-1] != llm.tokenizer.eos_token_id + ), ) return RolloutResult( From 041058345791212aca6f6127e8d8e839c08d1315 Mon Sep 17 00:00:00 2001 From: Alex Piche Date: Fri, 3 Oct 2025 14:57:42 +0000 Subject: [PATCH 26/35] upd --- conf/base.yaml | 4 ++-- pipelinerl/preprocess.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/conf/base.yaml b/conf/base.yaml index 1d91cdb9..28f0007a 100644 --- a/conf/base.yaml +++ b/conf/base.yaml @@ -1,5 +1,5 @@ defaults: - - finetune: actor_critic + - finetune: base - rewards: pure_success - streams: files - _self_ @@ -78,7 +78,7 @@ world: finetune_fraction: 4 # Number of environment servers per actor VLLM server - env_replicas_per_actor: 1 + env_replicas: 1 actor_group_port: 9000 environment_start_port: 7777 diff --git a/pipelinerl/preprocess.py b/pipelinerl/preprocess.py index d758ff36..372c7b5b 100644 --- a/pipelinerl/preprocess.py +++ b/pipelinerl/preprocess.py @@ -17,6 +17,7 @@ from typing import List import datasets +import random import transformers from litellm import BaseModel, Field @@ -554,6 +555,7 @@ def 
run_preprocessing_loop( batch_done = False start_writing = time.time() + random.shuffle(processed_entries_queue) while (len(processed_entries_queue) > 0 and not batch_done) or (cfg.preprocess.dataset_buffer_size and not batch_done): logger.debug(f"[inner loop] trainer {trainer_id} has {samples_per_trainer[trainer_id]} samples, target is {target_samples_per_lead}") if cfg.finetune.seq_packing: From 8945d0a01b6b1097d5fbdb7f4df4e0762aa3df5a Mon Sep 17 00:00:00 2001 From: Alex Piche Date: Sun, 5 Oct 2025 04:33:58 +0000 Subject: [PATCH 27/35] handle answer tags --- pipelinerl/domains/math/verifier_api.py | 45 +++++++++++++++---------- 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/pipelinerl/domains/math/verifier_api.py b/pipelinerl/domains/math/verifier_api.py index 7b3835a3..9fa9b4ef 100644 --- a/pipelinerl/domains/math/verifier_api.py +++ b/pipelinerl/domains/math/verifier_api.py @@ -88,34 +88,41 @@ def verify_answer(prediction: str, gold: str, strict: bool = True, max_predictio def verify_math(prediction: str, gold: str, strict: bool = True, max_prediction_length: int = 1000) -> str: + import re + try: - # Input Sanitization / Validation (very important) + # Input Sanitization / Validation if not isinstance(prediction, str) or not isinstance(gold, str): raise ValueError("Prediction and gold must be strings") + # Try extracting from \boxed{...} first boxed_start = prediction.rfind("\\boxed{") - if boxed_start < 0: - raise NoAnswerException() - - boxed_prediction = prediction[boxed_start:] - if "\\boxed{}" in boxed_prediction: - raise EmptyBoxedException() - - if len(boxed_prediction) > max_prediction_length: - raise UnparsableException() - + if boxed_start >= 0: + boxed_prediction = prediction[boxed_start:] + if "\\boxed{}" in boxed_prediction: + raise EmptyBoxedException() + if len(boxed_prediction) > max_prediction_length: + raise UnparsableException() + extracted_prediction = boxed_prediction + else: + # Fallback: look for <answer>...</answer> tags
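# Sketch of the fallback in action (hypothetical completion): a response like
# "The slope is <answer>\frac{1}{2}</answer>" contains no \boxed{}, so the
# regex on the next line pulls "\frac{1}{2}" out of the last <answer> block;
# if neither format is present, NoAnswerException fires as before.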
+ answer_match = re.findall(r"<answer>(.*?)</answer>", prediction, re.DOTALL) + if answer_match: + extracted_prediction = answer_match[-1].strip() # last one if multiple + else: + raise NoAnswerException() + + # Parse and verify gold_parsed = math_verify.parse(gold) - boxed_prediction_parsed = math_verify.parse(boxed_prediction) - if not boxed_prediction_parsed: + pred_parsed = math_verify.parse(extracted_prediction) + if not pred_parsed: raise ValueError("Failed to parse prediction.") with timeout(1): - equivalent = math_verify.verify(gold_parsed, boxed_prediction_parsed, strict=strict, timeout_seconds=1) - if equivalent: - answer_status = "correct" - else: - answer_status = "wrong" + equivalent = math_verify.verify(gold_parsed, pred_parsed, strict=strict, timeout_seconds=1) + + answer_status = "correct" if equivalent else "wrong" except Exception as e: match e: @@ -123,9 +130,11 @@ def verify_math(prediction: str, gold: str, strict: bool = True, max_prediction_ answer_status = "no_answer" case _: answer_status = "unparsable" + return answer_status + def verify_countdown(prediction: str, gold: str) -> str: target = eval(gold.split("-")[1]) numbers = eval(gold.split("-")[2]) From b2bda96a6f02244ef9f3244178f22137c99e65c1 Mon Sep 17 00:00:00 2001 From: Alex Piche Date: Sun, 5 Oct 2025 04:34:11 +0000 Subject: [PATCH 28/35] upd --- pipelinerl/domains/math/minimal_rollout.py | 72 ++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 pipelinerl/domains/math/minimal_rollout.py diff --git a/pipelinerl/domains/math/minimal_rollout.py b/pipelinerl/domains/math/minimal_rollout.py new file mode 100644 index 00000000..bc46d2a1 --- /dev/null +++ b/pipelinerl/domains/math/minimal_rollout.py @@ -0,0 +1,72 @@ +import time +import random + +import aiohttp +from omegaconf import DictConfig +from pydantic import BaseModel +from pipelinerl.rollouts import RolloutResult, BaseMetrics +from pipelinerl.world import Job +from tapeagents.core import Prompt +from tapeagents.llms.trainable import TrainableLLM + +from pipelinerl.async_llm import llm_async_generate, make_training_text +from .verifier_api import verify_answer_rpc + +class Metrics(BaseMetrics): + pass + +class RewardTable(BaseModel): + wrong_answer_not_finished: float + wrong_answer_finished: float + no_answer_not_finished: float + no_answer_finished: float + unparsable_not_finished: float + unparsable_finished: float + correct_answer_not_finished: float + correct_answer_finished: float + buffer_tokens: int = 0 # 0 means no overlong reward shaping + +def length_penalty(max_length: int, sequence_length: int, buffer_tokens: int) -> float: + """ + Compute the overlong penalty + """ + if sequence_length > (max_length - buffer_tokens) and sequence_length <= max_length: + return ((max_length - buffer_tokens) - sequence_length) / buffer_tokens + return 0.
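# Worked example with hypothetical settings: max_length=12000 and
# buffer_tokens=2000 leave a penalty-free zone up to 10000 tokens;
# length_penalty(12000, 11500, 2000) == (10000 - 11500) / 2000 == -0.75,
# a linear ramp from 0 at 10000 tokens down to -1.0 at exactly max_length.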
+ +def get_reward(trace, answer_status: str, rewards: RewardTable) -> float: + # Map the verifier status and the finished flag onto the reward table + # (assumes trace.finished is set by make_training_text). + key = {"correct": "correct_answer", "wrong": "wrong_answer"}.get(answer_status, answer_status) + return getattr(rewards, f"{key}_{'finished' if trace.finished else 'not_finished'}") + + +async def generate_math_rollout( + cfg: DictConfig, + llm: TrainableLLM, + problem: dict, + session: aiohttp.ClientSession, +) -> RolloutResult: + messages = [] + if cfg.actor.system_prompt: + messages.append({"role": "system", "content": cfg.actor.system_prompt}) + messages.append({"role": "user", "content": f"{problem['task']} \n{cfg.actor.task_prompt}"}) + prompt = Prompt(messages=messages) + + time_start = time.time() + llm_call = await llm_async_generate(llm, prompt, session) + latency = time.time() - time_start + + assert llm_call.output.content is not None + rewards = RewardTable(**dict(cfg.rewards)) + + env_jobs = [Job(**job) for job in cfg.jobs if job["kind"] == "environment"] + env_job = random.choice(env_jobs) + assert env_job.port is not None + answer_status = await verify_answer_rpc(session=session, host=env_job.hostname, port=env_job.port, prediction=llm_call.output.content, gold=problem["answer"]) + + trace = make_training_text(llm, llm_call) + reward = get_reward(trace, answer_status, rewards) + trace.reward = reward + + metrics = Metrics(reward=reward, success=answer_status == "correct", no_error=answer_status != "unparsable", no_answer=answer_status == "no_answer") + + + return RolloutResult(training_texts=[trace], metrics=metrics, latency=latency, dataset_name=problem.get("dataset")) From 3c12ffb737aabb100ebd48135dd1db75eb6fe297 Mon Sep 17 00:00:00 2001 From: rafapi Date: Tue, 7 Oct 2025 20:02:34 +0000 Subject: [PATCH 29/35] Add AIME2025 --- pipelinerl/domains/math/load_datasets.py | 26 ++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/pipelinerl/domains/math/load_datasets.py b/pipelinerl/domains/math/load_datasets.py index 4b44dfb6..e78385ec 100644 --- a/pipelinerl/domains/math/load_datasets.py +++ b/pipelinerl/domains/math/load_datasets.py @@ -152,6 +152,26 @@ def load_math(split): return datasets.Dataset.from_list(data) +def _load_aime_2025_opencompass_dataset(upsample_factor: int = 0) -> list[dict]: + configs = ["AIME2025-I", "AIME2025-II"] + dataset_name = "aime_2025" + ("" if upsample_factor > 0 else "_original") + + samples: list[dict] = [] + for config_name in configs: + ds = load_dataset("opencompass/AIME2025", config_name, split="test") + samples.extend([s for s in process_math(ds, dataset_name) if s is not None]) + + original_size = len(samples) + if upsample_factor > 0: + samples *= upsample_factor + + logger.info( + f"Loading aime 2025 (OpenCompass) dataset: {len(samples)} samples" + + (f" (upsampled from {original_size})" if upsample_factor > 0 else "") + ) + return add_ids(samples) + + def _load_aime_dataset(year: int, upsample_factor: int = 0) -> list[dict]: aime_dataset = load_dataset("AI-MO/aimo-validation-aime", split="train", trust_remote_code=True) aime_dataset = aime_dataset.filter(lambda x: str(year) in x["url"]) @@ -335,6 +355,12 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None if "aime_2024_original" in dataset_names: datasets += _load_aime_dataset(2024) + if "aime_2025" in dataset_names: + datasets += _load_aime_2025_opencompass_dataset(upsample_factor=16) + + if "aime_2025_original" in dataset_names: + datasets += _load_aime_2025_opencompass_dataset() + if "amc_2022" in dataset_names: # TODO: AMC 2022 is 43 problems, is that to be expected?
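# Note on the upsample_factor pattern used above and below: these eval sets
# are tiny (AIME has 30 problems per year across its two papers), so
# upsample_factor=16 repeats each sample to cut the variance of pass-rate
# estimates, while the *_original dataset names keep the raw, unrepeated set.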
datasets += _load_amc_dataset(2022, upsample_factor=16) From e6e44efee5955f2a6b41caff8ccbde07e4792cc7 Mon Sep 17 00:00:00 2001 From: rafapi Date: Fri, 10 Oct 2025 10:58:02 +0000 Subject: [PATCH 30/35] Do a bit more work to find answer --- pipelinerl/domains/math/rollouts.py | 156 +++++++++++++++++++++++++--- 1 file changed, 142 insertions(+), 14 deletions(-) diff --git a/pipelinerl/domains/math/rollouts.py b/pipelinerl/domains/math/rollouts.py index 3afd021c..ff9295f2 100644 --- a/pipelinerl/domains/math/rollouts.py +++ b/pipelinerl/domains/math/rollouts.py @@ -1,5 +1,7 @@ +import re import time import random +import logging import aiohttp from omegaconf import DictConfig @@ -12,9 +14,11 @@ from pipelinerl.async_llm import llm_async_generate, make_training_text from .verifier_api import verify_answer_rpc + class Metrics(BaseMetrics): penalty: float overflow: bool = False + auto_boxed: bool = False class RewardTable(BaseModel): wrong_answer_not_finished: float @@ -27,6 +31,85 @@ class RewardTable(BaseModel): correct_answer_finished: float buffer_tokens: int = 0 # 0 means no overlong reward shaping + +_BOXED_PREFIX = "\\boxed{" + + +def _find_last_boxed_span(text: str) -> tuple[int, int] | None: + start = text.rfind(_BOXED_PREFIX) + if start < 0: + return None + depth = 0 + for idx in range(start + len(_BOXED_PREFIX), len(text)): + ch = text[idx] + if ch == "{": + depth += 1 + elif ch == "}": + if depth == 0: + return start, idx + 1 + depth -= 1 + return None + + +_ANSWER_PREFIX_RE = re.compile( + r"^(final answer|answer|ans\.?|thus.*?is|therefore.*?is|so the answer is)[:=\-\s]*", + re.IGNORECASE, +) + + +def _strip_answer_prefix(line: str) -> str: + return _ANSWER_PREFIX_RE.sub("", line).strip() + + +_EXPRESSION_RE = re.compile(r"([-+]?\s*[^\s]+(?:\s*[^\s]+)*)") + + +def _extract_candidate_expression(text: str) -> str | None: + for raw_line in reversed(text.strip().splitlines()): + line = raw_line.strip() + if not line: + continue + line = _strip_answer_prefix(line.rstrip(".;!")) + if not line: + continue + if any(char.isdigit() for char in line) or "\\" in line: + return line + match = _EXPRESSION_RE.search(text.strip()) + return match.group(1).strip() if match else None + + +def ensure_boxed_answer(text: str) -> tuple[str, bool]: + """Return text contained in the last \boxed{} block.""" + if not text: + return text, False + + cleaned = text.rstrip() + span = _find_last_boxed_span(cleaned) + if span is not None: + start, end = span + prefix = cleaned[:start].rstrip() + boxed = cleaned[start:end] + suffix_adjusted = f"{prefix}\n{boxed}" if prefix else boxed + changed = suffix_adjusted != cleaned + return suffix_adjusted, changed + + candidate = _extract_candidate_expression(cleaned) + if not candidate: + return cleaned, False + + candidate_boxed = f"\\boxed{{{candidate}}}" + if candidate in cleaned: + prefix, sep, suffix = cleaned.rpartition(candidate) + if sep: + adjusted = f"{prefix}{candidate_boxed}{suffix}".rstrip() + else: + adjusted = cleaned + else: + adjusted = cleaned + if adjusted == cleaned: + adjusted = f"{cleaned}\n{candidate_boxed}" if cleaned else candidate_boxed + return adjusted, adjusted != cleaned + def length_penalty(max_length: int, sequence_length: int, buffer_tokens: int) -> float: """ Compute the overlong penalty @@ -44,7 +127,10 @@ async def generate_math_rollout( messages = [] if cfg.actor.system_prompt: messages.append({"role": "system", "content": cfg.actor.system_prompt}) - messages.append({"role": "user", "content": f"{problem['task']} 
\n{cfg.actor.task_prompt}"}) + user_content = problem["task"] + if cfg.actor.task_prompt: + user_content = f"{user_content} \n{cfg.actor.task_prompt}" + messages.append({"role": "user", "content": user_content}) prompt = Prompt(messages=messages) time_start = time.time() @@ -52,6 +138,12 @@ async def generate_math_rollout( latency = time.time() - time_start assert llm_call.output.content is not None + auto_boxed = False + if getattr(cfg.actor, "ensure_boxed_answers", False): + sanitized, changed = ensure_boxed_answer(llm_call.output.content) + if changed: + llm_call.output.content = sanitized + auto_boxed = True rewards = RewardTable(**dict(cfg.rewards)) discount_factor = cfg.actor.discount_factor @@ -60,14 +152,17 @@ async def generate_math_rollout( # choose the job randomly env_job = random.choice(env_jobs) assert env_job.port is not None - answer_status = await verify_answer_rpc( - session=session, - host=env_job.hostname, - port=env_job.port, - prediction=llm_call.output.content, - gold=problem["answer"], - strict=True, - ) + try: + answer_status = await verify_answer_rpc( + session=session, + host=env_job.hostname, + port=env_job.port, + prediction=llm_call.output.content, + gold=problem["answer"], + strict=True, + ) + except Exception as exc: + answer_status = "unparsable" trace = make_training_text(llm, llm_call) # Determine reward based on answer status and finished state @@ -99,17 +194,50 @@ async def generate_math_rollout( reward += overlong_penalty trace.reward = reward + # Prefer backend-provided finish reason if available; normalize for comparisons + if isinstance(trace.metadata, dict): + finish_reason_raw = trace.metadata.get("finish_reason") + else: + finish_reason_raw = None + + finish_reason = ( + str(finish_reason_raw).strip().lower() if finish_reason_raw is not None else None + ) + + # Overflow is true when generation hit the backend length cap explicitly + overflow_by_reason = finish_reason == "length" + + # Only fall back to heuristics when we lack a reliable finish reason (e.g., old backends) + use_fallback = finish_reason is None or finish_reason not in {"stop", "length"} + + overflow_by_length = False + overflow_by_eos = False + if use_fallback: + max_tokens = int(llm.parameters.get("max_tokens", 0) or 0) + if max_tokens > 0: + try: + overflow_by_length = trace.output_tokens >= max_tokens + except Exception: + overflow_by_length = False + + try: + eos_token_id = getattr(llm.tokenizer, "eos_token_id", None) + overflow_by_eos = bool( + trace.input_ids + and eos_token_id is not None + and trace.input_ids[-1] != eos_token_id + ) + except Exception: + overflow_by_eos = False + metrics = Metrics( reward=reward, success=answer_status == "correct", no_error=answer_status != "unparsable", no_answer=answer_status == "no_answer", penalty=overlong_penalty, - overflow=bool( - trace.input_ids - and getattr(llm.tokenizer, "eos_token_id", None) is not None - and trace.input_ids[-1] != llm.tokenizer.eos_token_id - ), + overflow=bool(overflow_by_reason or overflow_by_length or overflow_by_eos), + auto_boxed=auto_boxed, ) return RolloutResult( From bf3f403f92abdb7ac97f14890148721d583793a1 Mon Sep 17 00:00:00 2001 From: rafapi Date: Fri, 10 Oct 2025 10:58:48 +0000 Subject: [PATCH 31/35] Extract finish reason --- pipelinerl/async_llm.py | 47 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 44 insertions(+), 3 deletions(-) diff --git a/pipelinerl/async_llm.py b/pipelinerl/async_llm.py index e375b6a5..4d24bb89 100644 --- a/pipelinerl/async_llm.py +++ b/pipelinerl/async_llm.py @@ 
-7,6 +7,7 @@ from PIL import Image from tapeagents.core import LLMCall, LLMOutput, Prompt, TokenLogprob from tapeagents.llms.trainable import TrainableLLM +from omegaconf import DictConfig, ListConfig, OmegaConf from pipelinerl.finetune.data import MASKED_TOKEN_ID from pipelinerl.rollouts import TrainingText @@ -15,9 +16,19 @@ logger = logging.getLogger(__name__) +def _to_plain(obj): + """Convert Hydra containers to plain Python types for JSON.""" + if isinstance(obj, (DictConfig, ListConfig)): + return OmegaConf.to_container(obj, resolve=True) + if isinstance(obj, dict): + return {k: _to_plain(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [_to_plain(v) for v in obj] + return obj + + def extract_images_from_messages(messages: list[dict]) -> list[Image.Image]: """Extract PIL Images from multimodal messages.""" - images = [] for message in messages: if isinstance(message.get("content"), list): @@ -56,6 +67,8 @@ async def llm_async_generate( "messages": prompt.messages, "stream": llm.stream, } + if getattr(prompt, "tools", None): + data["tools"] = prompt.tools if llm.collect_logprobs: data.update( { @@ -67,9 +80,15 @@ logger.debug(f"POST request to {llm.base_url}/v1/chat/completions") + llm_params = _to_plain(getattr(llm, "parameters", {})) + if not isinstance(llm_params, dict): + llm_params = {"parameters": llm_params} + payload = _to_plain(data) + payload.update(llm_params) + async with session.post( url=f"{llm.base_url}/v1/chat/completions", - json=data | llm.parameters, + json=payload, headers=headers, ssl=False, ) as response: @@ -102,6 +121,12 @@ except Exception as e: logger.error(f"Failed to process logprobs: {logprob}") logger.error(e) + try: + finish_reason = data["choices"][0].get("finish_reason") + stop_reason = data["choices"][0].get("stop_reason") + except Exception: + finish_reason = None + stop_reason = None except Exception as e: logger.exception(f"Failed to parse llm response: {data}") raise e @@ -112,6 +137,11 @@ llm_call.output_length_tokens = data["usage"]["completion_tokens"] assert llm_call is not None, "llm_call is None" llm_call.logprobs = parsed_logprobs + # Store finish reason details for metrics + if finish_reason is not None: + llm_call.llm_info["finish_reason"] = finish_reason + if stop_reason is not None: + llm_call.llm_info["stop_reason"] = stop_reason return llm_call @@ -206,7 +236,17 @@ def make_training_text(llm: TrainableLLM, llm_call: LLMCall) -> TrainingText: # Apply masking to input tokens that aren't generated labels = [MASKED_TOKEN_ID] * len(prompt_token_ids) + labels logprobs = [lp.logprob for lp in llm_call.logprobs] - finished = llm_call.output.content.endswith(tokenizer.eos_token) + # Prefer backend-provided finish reason + finish_reason = None + try: + finish_reason = llm_call.llm_info.get("finish_reason") + except Exception: + finish_reason = None + + if finish_reason is not None: + finished = finish_reason == "stop" + else: + finished = llm_call.output.content.endswith(tokenizer.eos_token) prompt_tokens = llm_call.prompt_length_tokens output_tokens = llm_call.output_length_tokens @@ -220,4 +260,5 @@ prompt_tokens=prompt_tokens, output_tokens=output_tokens, visual_features=visual_features, + metadata={"finish_reason": finish_reason} if finish_reason is not None else {}, ) From 43cfe000abde572af746d6cd7bba352b47e7a9fd Mon Sep 17 00:00:00 2001 From: rafapi
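The net effect of patch 31 is that "finished" now comes from the inference server's finish_reason rather than a string-suffix check, which breaks when the tokenizer decodes without special tokens. A compact sketch of the decision rule; the function name and arguments are illustrative, not the pipelinerl API:

def is_finished(finish_reason, content, eos_token):
    # Trust the backend signal when present; fall back to the EOS heuristic.
    if finish_reason is not None:
        return finish_reason == "stop"
    return content.endswith(eos_token)

assert is_finished("stop", "x = 2", "</s>")
assert not is_finished("length", "x = ", "</s>")   # generation hit the token cap
assert is_finished(None, "x = 2</s>", "</s>")      # legacy path: EOS heuristic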
Date: Fri, 10 Oct 2025 10:59:35 +0000 Subject: [PATCH 32/35] Compute overlong and overflow --- pipelinerl/finetune/rl/__init__.py | 36 +++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/pipelinerl/finetune/rl/__init__.py b/pipelinerl/finetune/rl/__init__.py index 43d0ae48..1850cf1e 100644 --- a/pipelinerl/finetune/rl/__init__.py +++ b/pipelinerl/finetune/rl/__init__.py @@ -296,15 +296,29 @@ def rl_step( ) group_ratio_new_old = torch.exp(lrn_sum / tok_count.clamp(min=1e-6)).unsqueeze(1).unsqueeze(2) group_advantages_t = (adv_sum / tok_count.clamp(min=1e-6)).unsqueeze(1).unsqueeze(2).detach() + # Sum token weights per segment so GSPO respects normalization/overflow settings. + zero_weights = torch.zeros_like(tokens_weights) + weight_sum, _, _ = per_segment_sums( + batch.segment_ids, + masks_shifted, + tokens_weights, + zero_weights, + seq_parallel_group=seq_parallel_group, + ) + valid_mask = (tok_count > 0) & (weight_sum > 0) + valid_mask_3d = valid_mask.unsqueeze(1).unsqueeze(2) surr1 = group_ratio_new_old * group_advantages_t clamped_group_ratio = torch.clamp(group_ratio_new_old, 1 - config.epsilon_low, 1 + config.epsilon_high) - clamp_log_ratio_new_old_indicators = clamped_group_ratio != group_ratio_new_old + clamp_log_ratio_new_old_indicators = (clamped_group_ratio != group_ratio_new_old) & valid_mask_3d surr2 = clamped_group_ratio * group_advantages_t + sequence_weights = weight_sum.unsqueeze(1).unsqueeze(2) # If we have a sentinel or no segments, return a zero loss but keep graph if batch.sentinel or surr1.numel() == 0: policy_loss_total = new_logprobs[..., :1].sum() * 0.0 else: - policy_loss_total = -torch.min(surr1, surr2).sum() + mask_float = valid_mask_3d.to(dtype=surr1.dtype) + min_terms = torch.min(surr1, surr2) * mask_float * sequence_weights + policy_loss_total = -min_terms.sum() expanded_indicators = torch.zeros_like(masks_shifted, dtype=torch.float) # Expand per-sequence indicators to token-level across segment ranges # Flatten to 1-D so single-sequence cases don't produce 0-d tensors @@ -490,10 +504,20 @@ def calculate_advantages(row): assert len(df) == len(df_init) # Step 4: make token-level overflow and mean group length information - df["overflow"] = df.apply( - lambda row: [0.0] * len(row["overflow"]) if eos_token_id in row["input_ids"] else [1.0] * len(row["overflow"]), - axis=1, - ) + def _overflow_from_finish_reason(row): + length = len(row["overflow"]) + finish_reason = row.get("finish_reason") + if isinstance(finish_reason, str): + finish_reason = finish_reason.strip().lower() + if finish_reason == "length": + return [1.0] * length + if finish_reason in {"stop", "content_filter"}: + return [0.0] * length + if row.get("finished"): + return [0.0] * length + return [0.0] * length if eos_token_id in row["input_ids"] else [1.0] * length + + df["overflow"] = df.apply(_overflow_from_finish_reason, axis=1) df["group_tokens"] = df.apply(lambda row: [row["group_tokens"]] * len(row["input_ids"]), axis=1) df["num_labels"] = df.apply( lambda row: [sum(1 for label in row["labels"] if label != -100)] * len(row["input_ids"]), axis=1 From 0d569edee794320b609fed94cf65cdfa7f799dc0 Mon Sep 17 00:00:00 2001 From: rafapi Date: Fri, 10 Oct 2025 11:00:12 +0000 Subject: [PATCH 33/35] Extract finish reason --- pipelinerl/preprocess.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pipelinerl/preprocess.py b/pipelinerl/preprocess.py index 372c7b5b..325a4cd9 100644 --- a/pipelinerl/preprocess.py +++ b/pipelinerl/preprocess.py @@ 
-154,6 +154,11 @@ def preprocess_dataset( entry = dict(data[i]) for k, v in preprocess(data[i]).items(): entry[k] = v + metadata = entry.get("metadata") + if isinstance(metadata, dict): + entry["finish_reason"] = metadata.get("finish_reason") + else: + entry["finish_reason"] = None dataset.append(entry) for entry in dataset: entry["model_version"] = entry["metadata"]["model_version"] @@ -666,4 +671,3 @@ def run_preprocessing_loop( if worker.is_alive(): worker.terminate() worker.join(timeout=1.0) - From 1c4e197fc3d3e199d2a58dc96e16181ee497e63e Mon Sep 17 00:00:00 2001 From: rafapi Date: Fri, 10 Oct 2025 11:22:43 +0000 Subject: [PATCH 34/35] Update versions --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f950d75d..e25e3ee0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,10 +14,10 @@ authors = [ ] dependencies = [ "torch>=2.6", - "vllm==0.8.3", + "vllm==0.8.5.post1", "accelerate==1.7.0", - "Tapeagents[finetune]==0.1.15", - "transformers==4.51.0", + "Tapeagents[finetune]==0.1.16", + "transformers==4.51.1", "flash-attn==2.7.4.post1", "ring-flash-attn==0.1.6", "math-verify[antlr4_9_3]==0.7.0", From 52f22ad3e027a5c7f1b4fb6915c73e2099319762 Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Fri, 10 Oct 2025 17:06:03 +0200 Subject: [PATCH 35/35] update vllm extensions to be compatible with 0.8.5 --- pipelinerl/vllm0.py | 3 +-- pipelinerl/vllm1.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/pipelinerl/vllm0.py b/pipelinerl/vllm0.py index 92c51085..8cd023bd 100644 --- a/pipelinerl/vllm0.py +++ b/pipelinerl/vllm0.py @@ -228,8 +228,7 @@ async def _receive_weight_update(request: WeightUpdateRequest): await weight_update_manager.receive_weight_update(request) return {"status": "ok"} - model_config = await engine.get_model_config() - await init_app_state(engine, model_config, app.state, args) + await init_app_state(engine, engine_config, app.state, args) shutdown_task = await serve_http( app, sock, diff --git a/pipelinerl/vllm1.py b/pipelinerl/vllm1.py index 80cba297..be98f76f 100644 --- a/pipelinerl/vllm1.py +++ b/pipelinerl/vllm1.py @@ -172,8 +172,7 @@ async def _receive_weight_update(request: WeightUpdateRequest): await weight_update_manager.receive_weight_update(request) return {"status": "ok"} - model_config = await engine.get_model_config() - await init_app_state(engine, engine_config, app.state, args) + await init_app_state(engine, engine_config, app.state, args) shutdown_task = await serve_http( app, sock,
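Patch 35 tracks an upstream API change: as the two hunks show, init_app_state in vLLM 0.8.5 takes the full engine config where older versions took a ModelConfig fetched via engine.get_model_config(). If both vLLM versions ever needed to coexist, a version guard along these lines would work; the 0.8.5 boundary is taken from this patch, and packaging is an assumed dependency:

from packaging.version import Version
import vllm

# Dispatch on the installed vLLM version; only the second argument differs.
if Version(vllm.__version__) >= Version("0.8.5"):
    await init_app_state(engine, engine_config, app.state, args)
else:
    model_config = await engine.get_model_config()
    await init_app_state(engine, model_config, app.state, args)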