Skip to content

🦜Enhance repetition penalty reward for languages that cannot be split by whitespace #516

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions src/open_r1/rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,21 +280,33 @@ def cosine_scaled_reward(completions, solution, **kwargs):
return cosine_scaled_reward


def get_repetition_penalty_reward(ngram_size: int, max_penalty: float):
def get_repetition_penalty_reward(ngram_size: int, max_penalty: float, language: str = "en"):
"""
Computes N-gram repetition penalty as described in Appendix C.2 of https://arxiv.org/abs/2502.03373.
Reference implementation from: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py

Args:
ngram_size: size of the n-grams
max_penalty: Maximum (negative) penalty for wrong answers
language: Language of the text, defaults to `en`. Used to choose the way to split the text into n-grams.
"""
if max_penalty > 0:
raise ValueError(f"max_penalty {max_penalty} should not be positive")

def zipngram(text: str, ngram_size: int):
words = text.lower().split()
return zip(*[words[i:] for i in range(ngram_size)])
if language == "en":
def zipngram(text: str, ngram_size: int):
words = text.lower().split()
return zip(*[words[i:] for i in range(ngram_size)]), words
elif language == "zh":
from transformers.utils.import_utils import _is_package_available
if not _is_package_available("jieba"):
raise ValueError("Please install jieba to use Chinese language")
def zipngram(text: str, ngram_size: int):
import jieba
seg_list = list(jieba.cut(text))
return zip(*[seg_list[i:] for i in range(ngram_size)]), seg_list
else:
raise ValueError(f"Word splitting for language `{language}` is not yet implemented. Please implement your own zip-ngram function.")

def repetition_penalty_reward(completions, **kwargs) -> float:
"""
Expand All @@ -311,13 +323,16 @@ def repetition_penalty_reward(completions, **kwargs) -> float:
if completion == "":
rewards.append(0.0)
continue
if len(completion.split()) < ngram_size:
rewards.append(0.0)
continue

ngrams = set()
total = 0
for ng in zipngram(completion, ngram_size):
ngram_array, words = zipngram(completion, ngram_size)

if len(words) < ngram_size:
rewards.append(0.0)
continue

for ng in ngram_array:
ngrams.add(ng)
total += 1

Expand Down
12 changes: 12 additions & 0 deletions tests/test_rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,18 @@ def test_tag_count_rewards_missing_all_tags(self):
rewards = tag_count_reward(completion)
self.assertEqual(rewards[0], 0.0)

def test_full_repetition_with_language(self):
    """Check the repetition penalty on fully repetitive text for `en` and `zh`.

    The English case always runs. The Chinese case needs the optional
    `jieba` package; when it is missing the test is reported as skipped
    (the English assertion above has already been checked by then).
    """
    reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0, language="en")
    completions = [[{"content": "that that that that that"}]]
    rewards = reward_fn(completions)
    self.assertEqual(rewards, [-0.75])

    # begin test for zh language
    # Only a missing package should trigger the skip; a bare `except:` would
    # also hide KeyboardInterrupt/SystemExit and genuine errors raised while
    # importing jieba.
    try:
        import jieba  # noqa: F401  (imported only to probe availability)
    except ImportError:
        self.skipTest("jieba is not installed")

    reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0, language="zh")
    # NOTE(review): presumably jieba segments this into five repeated tokens,
    # mirroring the English case — confirm against jieba's dictionary.
    completions = [[{"content": "这个这个这个这个这个"}]]
    rewards = reward_fn(completions)
    self.assertEqual(rewards, [-0.75])

class TestCodeFormat(unittest.TestCase):
def test_correct_python_format(self):
Expand Down