[Judges] rlhflow pairwise judges #2548
base: main
@@ -18,6 +18,7 @@

from typing import Optional, Union

import numpy as np
import torch
from accelerate import Accelerator
from huggingface_hub import InferenceClient
from transformers.utils import is_openai_available
@@ -455,3 +456,133 @@ def judge(

            else:
                output.append(0)
        return output

class RLHFlowPairwiseJudge(BasePairwiseJudge):
    """
    Pairwise judge based on RLHFlow-style preference models.

    This judge uses a preference model trained with RLHFlow to compare two responses and determine which is better.

    Args:
        model_name (`str`): The name or path of the preference model to use.
        device (`str`, optional): Device to load the model on. Defaults to "cuda" if available.
        **model_kwargs: Additional keyword arguments to pass to `AutoModelForCausalLM.from_pretrained()`.

    Example:
        ```python
        rlhflow_judge = RLHFlowPairwiseJudge("TianqiLiuAI/RRM-0p2")
        prompts = ["What's the capital of France?"]
        completions = [["Paris", "Lyon"]]
        results = rlhflow_judge.judge(prompts, completions)
        ```
    """

    def __init__(self, model_name: str, device: Optional[str] = None, **model_kwargs):
        from transformers import AutoModelForCausalLM, AutoTokenizer

        # Setup device
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        # Initialize model and tokenizers
        self.model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs).to(device)
        self.model.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        self.tokenizer_plain = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        self.tokenizer_plain.chat_template = "\n{% for message in messages %}{% if loop.index0 % 2 == 0 %}\n\n<turn> user\n {{ message['content'] }}{% endif %}{% endfor %}"

        # Make batched inference well-defined: fall back to the EOS token if no pad token is set, and
        # left-pad so the last position of every row holds the final real token, whose logits are read
        # in `_process_batch`.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"

        # Setup constants
        self.prompt_template = "[CONTEXT] {context} [RESPONSE A] {response_A} [RESPONSE B] {response_B} \n"
        self.token_id_A = self.tokenizer.encode("A", add_special_tokens=False)[0]
        self.token_id_B = self.tokenizer.encode("B", add_special_tokens=False)[0]
        self.device = device

    def _process_batch(self, batch_prompts, batch_completions, temperature=1.0, return_scores=False):
        # Process each prompt-completion pair in the batch
        inputs = []
        for prompt, completion_pair in zip(batch_prompts, batch_completions):
            # Convert prompt to chat format
            instruction = [{"role": "user", "content": prompt}]
            context = self.tokenizer_plain.apply_chat_template(instruction, tokenize=False)

Reviewer comment: I recommend using […]. While one could argue that we control the chat template in this context, using […]

Author reply: so both this and the below changes are how the RLHF model recommends to do the scoring... I can check if it works using the chat template

            # Format prompt with completions
            prompt = self.prompt_template.format(
                context=context, response_A=completion_pair[0], response_B=completion_pair[1]
            )
            message = [{"role": "user", "content": prompt}]

            # Prepare input for tokenization
            input_text = self.tokenizer.apply_chat_template(message, tokenize=False).replace(
                self.tokenizer.bos_token, ""
            )
            inputs.append(input_text)

        # Batch tokenize all inputs
        batch_inputs = self.tokenizer(
            inputs,
            return_tensors="pt",
            add_special_tokens=False,
            padding=True,
        ).to(self.device)

        # Process entire batch at once
        with torch.inference_mode():
            outputs = self.model(**batch_inputs)

        # Get the logits of the "A" and "B" tokens at the last position of each sequence in the batch
        logits = outputs.logits[:, -1, [self.token_id_A, self.token_id_B]] / temperature
        probs = torch.softmax(logits, dim=-1)

        if return_scores:
            return probs[:, 0].tolist()  # probability of option A for all items in the batch
        else:
            return torch.where(probs[:, 0] > 0.5, 0, 1).tolist()  # rank is 0 if A is better, 1 if B is better

    def judge(
        self,
        prompts: list[str],
        completions: list[list[str]],
        shuffle_order: bool = True,
        return_scores: bool = False,
        temperature: float = 1.0,
        batch_size: int = 8,
    ) -> list[Union[int, float]]:
        """
        Judge the completion pairs for the given prompts.

        Args:
            prompts (`list[str]`): List of prompts to judge.
            completions (`list[list[str]]`): List of completion pairs for each prompt.
            shuffle_order (`bool`, optional): Whether to shuffle the order of completions. Defaults to True.
            return_scores (`bool`, optional): If True, return probability scores instead of binary choices.
                Defaults to False.
            temperature (`float`, optional): Temperature for softmax scaling. Defaults to 1.0.
            batch_size (`int`, optional): Batch size for processing. Defaults to 8.

        Returns:
            `list[Union[int, float]]`: List of preferred indices (0 or 1) or scores for each prompt pair.
        """
        if shuffle_order:
            flip_mask = np.random.choice([True, False], size=len(prompts))
            completions = [pair[::-1] if flip else pair for flip, pair in zip(flip_mask, completions)]

        results = []
        for i in range(0, len(prompts), batch_size):
            batch_prompts = prompts[i : i + batch_size]
            batch_completions = completions[i : i + batch_size]

            batch_results = self._process_batch(batch_prompts, batch_completions, temperature, return_scores)
            results.extend(batch_results)

        # Flip results back if order was shuffled
        if shuffle_order:
            if return_scores:
                results = [1.0 - result if flip else result for result, flip in zip(results, flip_mask)]
            else:
                results = [int(1 - result) if flip else int(result) for result, flip in zip(results, flip_mask)]

        return results
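
Beyond the docstring example, here is a hedged end-to-end sketch of how the new judge would be called (assuming the class is exported from `trl` like the existing judges; the model name is simply reused from the docstring example):

```python
from trl import RLHFlowPairwiseJudge  # assumption: exported alongside the other judges

judge = RLHFlowPairwiseJudge("TianqiLiuAI/RRM-0p2")  # model name from the docstring example

prompts = ["What's the capital of France?", "What's the smallest prime number?"]
completions = [["Paris", "Lyon"], ["2", "1"]]

# Hard choices: 0 means the first completion is preferred, 1 means the second
ranks = judge.judge(prompts, completions)

# Soft scores: probability that the first completion is the better one
scores = judge.judge(prompts, completions, return_scores=True, batch_size=8)
```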
Reviewer comment: Why do you need to override the chat template btw?
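
Not part of the diff, but one way to sanity-check the open question above is to render the same message with the model's own chat template and with the `<turn> user` override used to build the `[CONTEXT]` block, then compare the strings. This is only a sketch: the model name comes from the docstring example, and whether the default template would work is exactly what the thread leaves unresolved.

```python
from transformers import AutoTokenizer

model_name = "TianqiLiuAI/RRM-0p2"  # assumption: the model from the docstring example

tok_default = AutoTokenizer.from_pretrained(model_name, use_fast=True)
tok_plain = AutoTokenizer.from_pretrained(model_name, use_fast=True)
tok_plain.chat_template = (
    "\n{% for message in messages %}{% if loop.index0 % 2 == 0 %}"
    "\n\n<turn> user\n {{ message['content'] }}{% endif %}{% endfor %}"
)

instruction = [{"role": "user", "content": "What's the capital of France?"}]
print(repr(tok_default.apply_chat_template(instruction, tokenize=False)))  # the model's own format
print(repr(tok_plain.apply_chat_template(instruction, tokenize=False)))    # the "<turn> user" context format
```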