Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Judges] rlhflow pairwise judges #2548

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion tests/test_judges.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import time
import unittest

from trl import AllTrueJudge, HfPairwiseJudge, PairRMJudge
from trl import AllTrueJudge, HfPairwiseJudge, PairRMJudge, RLHFlowPairwiseJudge

from .testing_utils import RandomBinaryJudge, require_llm_blender

Expand Down Expand Up @@ -74,3 +75,23 @@ def test_pair_rm_judge_return_scores(self):
self.assertEqual(len(probs), 2)
self.assertTrue(all(isinstance(prob, float) for prob in probs))
self.assertTrue(all(0 <= prob <= 1 for prob in probs))

# skip this test on windows
@unittest.skipIf(sys.platform == "win32", "Skipping test on Windows")
def test_rlhflow_pairwise_judge(self):
judge = RLHFlowPairwiseJudge("TianqiLiuAI/RRM-0p2")
prompts, completions = self._get_prompts_and_pairwise_completions()
ranks = judge.judge(prompts=prompts, completions=completions, batch_size=2)
self.assertEqual(len(ranks), 2)
self.assertTrue(all(isinstance(rank, int) for rank in ranks))
self.assertEqual(ranks, [0, 1])

# skip this test on windows
@unittest.skipIf(sys.platform == "win32", "Skipping test on Windows")
def test_rlhflow_pairwise_judge_return_scores(self):
judge = RLHFlowPairwiseJudge("TianqiLiuAI/RRM-0p2")
prompts, completions = self._get_prompts_and_pairwise_completions()
probs = judge.judge(prompts=prompts, completions=completions, return_scores=True, batch_size=2)
self.assertEqual(len(probs), 2)
self.assertTrue(all(isinstance(prob, float) for prob in probs))
self.assertTrue(all(0 <= prob <= 1 for prob in probs))
2 changes: 2 additions & 0 deletions trl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
"PRMTrainer",
"RewardConfig",
"RewardTrainer",
"RLHFlowPairwiseJudge",
"RLOOConfig",
"RLOOTrainer",
"SFTConfig",
Expand Down Expand Up @@ -175,6 +176,7 @@
PRMTrainer,
RewardConfig,
RewardTrainer,
RLHFlowPairwiseJudge,
RLOOConfig,
RLOOTrainer,
SFTConfig,
Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
"HfPairwiseJudge",
"OpenAIPairwiseJudge",
"PairRMJudge",
"RLHFlowPairwiseJudge",
],
"kto_config": ["KTOConfig"],
"kto_trainer": ["KTOTrainer"],
Expand Down Expand Up @@ -115,6 +116,7 @@
HfPairwiseJudge,
OpenAIPairwiseJudge,
PairRMJudge,
RLHFlowPairwiseJudge,
)
from .kto_config import KTOConfig
from .kto_trainer import KTOTrainer
Expand Down
131 changes: 131 additions & 0 deletions trl/trainer/judges.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Optional, Union

import numpy as np
import torch
from accelerate import Accelerator
from huggingface_hub import InferenceClient
from transformers.utils import is_openai_available
Expand Down Expand Up @@ -455,3 +456,133 @@ def judge(
else:
output.append(0)
return output


class RLHFlowPairwiseJudge(BasePairwiseJudge):
"""
Pairwise judge based on RLHFlow-style preference models.

This judge uses a preference model trained with RLHFlow to compare two responses and determine which is better.

Args:
model_name (`str`): The name or path of the preference model to use.
device (`str`, optional): Device to load the model on. Defaults to "cuda" if available.
**model_kwargs: Additional keyword arguments to pass to AutoModelForCausalLM.from_pretrained()

Example:
```python
rlhflow_judge = RLHFlowPairwiseJudge("TianqiLiuAI/RRM-0p2")
prompts = ["What's the capital of France?"]
completions = [["Paris", "Lyon"]]
results = rlhflow_judge.judge(prompts, completions)
```
"""

def __init__(self, model_name: str, device: Optional[str] = None, **model_kwargs):
from transformers import AutoModelForCausalLM, AutoTokenizer

# Setup device
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize model and tokenizers
self.model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs).to(device)
self.model.eval()

self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
self.tokenizer_plain = AutoTokenizer.from_pretrained(model_name, use_fast=True)
self.tokenizer_plain.chat_template = "\n{% for message in messages %}{% if loop.index0 % 2 == 0 %}\n\n<turn> user\n {{ message['content'] }}{% endif %}{% endfor %}"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you need to override the chat template btw?


# Setup constants
self.prompt_template = "[CONTEXT] {context} [RESPONSE A] {response_A} [RESPONSE B] {response_B} \n"
self.token_id_A = self.tokenizer.encode("A", add_special_tokens=False)[0]
self.token_id_B = self.tokenizer.encode("B", add_special_tokens=False)[0]
self.device = device

def _process_batch(self, batch_prompts, batch_completions, temperature=1.0, return_scores=False):
# Process each prompt-completion pair in the batch
inputs = []
for prompt, completion_pair in zip(batch_prompts, batch_completions):
# Convert prompt to chat format
instruction = [{"role": "user", "content": prompt}]
context = self.tokenizer_plain.apply_chat_template(instruction, tokenize=False)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I recommend using trl.apply_chat_template from the trl data utils here. We've encountered several issues in the past when applying chat templates to partial sequences, and this approach would be more robust.

While one could argue that we control the chat template in this context, using trl.apply_chat_template ensures that any future modifications to the chat template won't introduce unexpected issues here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so both this and the below changes are how the RLHF model recommends to do the scoring... I can check if it works using the chat template


# Format prompt with completions
prompt = self.prompt_template.format(
context=context, response_A=completion_pair[0], response_B=completion_pair[1]
)
message = [{"role": "user", "content": prompt}]

# Prepare input for tokenization
input_text = self.tokenizer.apply_chat_template(message, tokenize=False).replace(
self.tokenizer.bos_token, ""
)
inputs.append(input_text)

# Batch tokenize all inputs
batch_inputs = self.tokenizer(
inputs,
return_tensors="pt",
add_special_tokens=False,
padding=True,
).to(self.device)

# Process entire batch at once
with torch.inference_mode():
outputs = self.model(**batch_inputs)

# Get logits for each sequence in batch
logits = outputs.logits[:, -1, [self.token_id_A, self.token_id_B]] / temperature
probs = torch.softmax(logits, dim=-1)

if return_scores:
return probs[:, 0].tolist() # Get probability for option A for all items in batch
else:
return torch.where(probs[:, 0] > 0.5, 0, 1).tolist() # rank is 0 if A is better, 1 if B is better

def judge(
self,
prompts: list[str],
completions: list[list[str]],
shuffle_order: bool = True,
return_scores: bool = False,
temperature: float = 1.0,
batch_size: int = 8,
) -> list[Union[int, float]]:
"""
Judge the completion pairs for the given prompts.

Args:
prompts (`List[str]`): List of prompts to judge.
completions (`List[List[str]]`): List of completion pairs for each prompt.
shuffle_order (`bool`, optional): Whether to shuffle the order of completions.
Defaults to True.
return_scores (`bool`, optional): If True, return probability scores instead
of binary choices. Defaults to False.
temperature (`float`, optional): Temperature for softmax scaling. Defaults to 1.0.
batch_size (`int`, optional): Batch size for processing. Defaults to 8.

Returns:
List[Union[int, float]]: List of preferred indices (0 or 1) or scores for
each prompt pair.
"""
if shuffle_order:
flip_mask = np.random.choice([True, False], size=len(prompts))
completions = [pair[::-1] if flip else pair for flip, pair in zip(flip_mask, completions)]

results = []
for i in range(0, len(prompts), batch_size):
batch_prompts = prompts[i : i + batch_size]
batch_completions = completions[i : i + batch_size]

batch_results = self._process_batch(batch_prompts, batch_completions, temperature, return_scores)
results.extend(batch_results)

# Flip results back if order was shuffled
if shuffle_order:
if return_scores:
results = [1.0 - result if flip else result for result, flip in zip(results, flip_mask)]
else:
results = [int(1 - result) if flip else int(result) for result, flip in zip(results, flip_mask)]

return results
Loading