Skip to content

Commit 450d882

Browse files
sizhit2The tunix Authors
authored andcommitted
Add overlong reward shaping for DAPO.
Refactor rl_learner.compute_reward to use reward_manager Enable logging `algo_config` at learner init. PiperOrigin-RevId: 852838075
1 parent 6f1eb71 commit 450d882

14 files changed

+819
-142
lines changed

tests/rl/algorithm_config_test.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from absl.testing import parameterized
1717
from tunix.rl import algorithm_config
1818

19-
2019
class AlgorithmConfigTest(parameterized.TestCase):
2120

2221
def test_defaults_are_valid(self):
@@ -30,12 +29,19 @@ def test_defaults_are_valid(self):
3029
self.fail(f"Default AlgorithmConfig values raised ValueError: {e}")
3130

3231
@parameterized.named_parameters(
33-
dict(testcase_name="gspo_gae_ppo", algo="gspo", adv="gae", loss="ppo"),
32+
dict(
33+
testcase_name="gspo_gae_ppo", algo="gspo-token", adv="gae", loss="ppo"
34+
),
3435
dict(
3536
testcase_name="grpo_grpo_grpo", algo="grpo", adv="grpo", loss="grpo"
3637
),
3738
dict(testcase_name="ppo_gae_ppo", algo="ppo", adv="gae", loss="ppo"),
38-
dict(testcase_name="gspo_grpo_ppo", algo="gspo", adv="grpo", loss="ppo"),
39+
dict(
40+
testcase_name="gspo_grpo_ppo",
41+
algo="gspo-token",
42+
adv="grpo",
43+
loss="ppo",
44+
),
3945
)
4046
def test_valid_combinations(self, algo: str, adv: str, loss: str):
4147
"""Tests various valid combinations of core algorithm parameters."""
@@ -54,7 +60,6 @@ def test_valid_combinations(self, algo: str, adv: str, loss: str):
5460
)
5561

5662
@parameterized.named_parameters(
57-
dict(testcase_name="invalid_algo_dapo", value="dapo"),
5863
dict(testcase_name="invalid_algo_else", value="something_else"),
5964
)
6065
def test_invalid_algo_variant(self, value: str):
@@ -91,12 +96,14 @@ def test_kw_only_enforcement(self):
9196
"""Ensures that positional arguments are not allowed."""
9297
with self.assertRaises(TypeError):
9398
# Attempt to initialize with positional arguments
94-
algorithm_config.AlgorithmConfig("grpo", "grpo", "grpo")
99+
algorithm_config.AlgorithmConfig("grpo-token", "grpo", "grpo")
95100

96101
# Check that standard keyword initialization works
97102
try:
98103
algorithm_config.AlgorithmConfig(
99-
algo_variant="gspo", advantage_estimator="gae", policy_loss_fn="ppo"
104+
algo_variant="gspo-token",
105+
advantage_estimator="gae",
106+
policy_loss_fn="ppo",
100107
)
101108
except TypeError:
102109
self.fail("Keyword arguments failed for kw_only dataclass")
@@ -117,6 +124,24 @@ def test_field_assignment(self):
117124
config.algo_variant = "invalid_after_init"
118125
self.assertEqual(config.algo_variant, "invalid_after_init")
119126

127+
def test_config_logging(self):
128+
"""Tests that configuration is logged correctly upon initialization."""
129+
# assertLogs catches logs at the specified level or higher
130+
with self.assertLogs(level="INFO") as log:
131+
algorithm_config.AlgorithmConfig(
132+
algo_variant="gspo-token",
133+
advantage_estimator="gae",
134+
policy_loss_fn="ppo",
135+
)
136+
137+
# log.output is a list of strings like ['INFO:root:message...']
138+
full_log_output = "\n".join(log.output)
139+
140+
self.assertIn("Initializing AlgorithmConfig", full_log_output)
141+
self.assertIn("algo_variant: gspo", full_log_output)
142+
self.assertIn("advantage_estimator: gae", full_log_output)
143+
self.assertIn("policy_loss_fn: ppo", full_log_output)
144+
120145

121146
if __name__ == "__main__":
122147
absltest.main()

tests/rl/function_registry_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def test_custom_categories_instance(self):
5252
def test_empty_categories_instance(self):
5353
# Test-specific instance for empty categories
5454
registry = function_registry.FunctionRegistry(allowed_categories=[])
55-
self.assertLen(registry.list_categories(), 2)
55+
self.assertLen(registry.list_categories(), 3)
5656

5757
@parameterized.named_parameters(
5858
dict(

tests/rl/grpo/dapo_learner_test.py

Lines changed: 192 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,16 +101,203 @@ def test_diff_loss(self):
101101
grpo_loss.item(),
102102
msg=(
103103
"DAPO and GRPO loss values should be different for the same input"
104-
" due to different configurations and potentially different"
105-
" logic."
104+
" due to different loss aggregation logics."
106105
),
107106
)
108107

109108
self.assertIn("kl", dapo_aux)
110109
self.assertIn("kl", grpo_aux)
111-
self.assertNotEqual(
112-
dapo_aux["kl"], grpo_aux["kl"]
113-
) # Expected as beta differs
110+
self.assertEqual(dapo_aux["kl"], 0.0) # DAPO does not have KL term.
111+
112+
113+
class TestDAPOConfigPostInit(parameterized.TestCase):
114+
115+
def test_valid_default(self):
116+
"""Tests that default values pass validation."""
117+
try:
118+
dapo_lib.DAPOConfig()
119+
except ValueError as e:
120+
self.fail(f"DAPOConfig raised ValueError on default initialization: {e}")
121+
122+
@parameterized.named_parameters(
123+
dict(testcase_name="custom_epsilons", epsilon=0.1, epsilon_high=0.15),
124+
dict(testcase_name="epsilons_equal", epsilon=0.1, epsilon_high=0.1),
125+
dict(
126+
testcase_name="buffer_disabled",
127+
overlong_buffer={"enable": False},
128+
),
129+
dict(testcase_name="buffer_none", overlong_buffer=None),
130+
dict(
131+
testcase_name="valid_buffer",
132+
overlong_buffer={
133+
"enable": True,
134+
"overlong_buffer_length": 2000,
135+
"overlong_buffer_penalty": 0.5,
136+
"max_response_length": 10000,
137+
},
138+
),
139+
)
140+
def test_valid_configurations(self, **kwargs):
141+
"""Tests various valid custom configurations."""
142+
try:
143+
dapo_lib.DAPOConfig(**kwargs)
144+
except ValueError as e:
145+
self.fail(f"DAPOConfig raised ValueError for valid case {kwargs}: {e}")
146+
147+
@parameterized.named_parameters(
148+
dict(
149+
testcase_name="invalid_epsilon_high",
150+
config_kwargs=dict(epsilon=0.2, epsilon_high=0.1),
151+
expected_regex=(
152+
"epsilon_high must be greater than or equal to epsilon."
153+
),
154+
),
155+
dict(
156+
testcase_name="buffer_missing_length",
157+
config_kwargs=dict(
158+
overlong_buffer={
159+
"enable": True,
160+
"overlong_buffer_penalty": 1.0,
161+
"max_response_length": 20480,
162+
}
163+
),
164+
expected_regex=(
165+
"overlong_buffer is enabled but missing.*overlong_buffer_length.*"
166+
),
167+
),
168+
dict(
169+
testcase_name="buffer_missing_penalty",
170+
config_kwargs=dict(
171+
overlong_buffer={
172+
"enable": True,
173+
"overlong_buffer_length": 4096,
174+
"max_response_length": 20480,
175+
}
176+
),
177+
expected_regex=(
178+
"overlong_buffer is enabled but missing"
179+
".*overlong_buffer_penalty.*"
180+
),
181+
),
182+
dict(
183+
testcase_name="buffer_missing_max_length",
184+
config_kwargs=dict(
185+
overlong_buffer={
186+
"enable": True,
187+
"overlong_buffer_length": 4096,
188+
"overlong_buffer_penalty": 1.0,
189+
}
190+
),
191+
expected_regex=(
192+
"overlong_buffer is enabled but missing.*max_response_length.*"
193+
),
194+
),
195+
dict(
196+
testcase_name="buffer_length_is_none",
197+
config_kwargs=dict(
198+
overlong_buffer={
199+
"enable": True,
200+
"overlong_buffer_length": None,
201+
"overlong_buffer_penalty": 1.0,
202+
"max_response_length": 20480,
203+
}
204+
),
205+
expected_regex=(
206+
"overlong_buffer is enabled but missing.*overlong_buffer_length.*"
207+
),
208+
),
209+
dict(
210+
testcase_name="negative_penalty",
211+
config_kwargs=dict(
212+
overlong_buffer={
213+
"enable": True,
214+
"overlong_buffer_length": 4096,
215+
"overlong_buffer_penalty": -0.5,
216+
"max_response_length": 20480,
217+
}
218+
),
219+
expected_regex="overlong_buffer_penalty must be non-negative",
220+
),
221+
)
222+
def test_invalid_configurations(self, config_kwargs, expected_regex):
223+
"""Tests various invalid configurations that should raise ValueError."""
224+
with self.assertRaisesRegex(ValueError, expected_regex):
225+
dapo_lib.DAPOConfig(**config_kwargs)
226+
227+
228+
class RewardShapingTest(parameterized.TestCase):
229+
230+
def setUp(self):
231+
super().setUp()
232+
self.mock_cluster = mock.MagicMock()
233+
234+
def test_raises_error_on_none_buffer(self):
235+
with self.assertRaisesRegex(
236+
ValueError, "reward_shaping is called but with empty overlong_buffer."
237+
):
238+
239+
dapo_lib.reward_shaping(
240+
prompts=["test prompt"],
241+
completions=["test completion"],
242+
mode=self.mock_cluster.Mode,
243+
overlong_buffer=None,
244+
)
245+
246+
@parameterized.named_parameters(
247+
dict(
248+
testcase_name="under_length",
249+
lengths=[70],
250+
expected_scores=[0.0],
251+
),
252+
dict(
253+
testcase_name="at_expected_length",
254+
lengths=[80],
255+
expected_scores=[0.0],
256+
),
257+
dict(
258+
testcase_name="in_buffer_zone",
259+
lengths=[90],
260+
expected_scores=[-5.0],
261+
),
262+
dict(
263+
testcase_name="at_max_length",
264+
lengths=[100],
265+
expected_scores=[-10.0],
266+
),
267+
dict(
268+
testcase_name="over_max_length",
269+
lengths=[110],
270+
expected_scores=[-15.0],
271+
),
272+
dict(
273+
testcase_name="mixed_lengths",
274+
lengths=[70, 80, 90, 100, 110],
275+
expected_scores=[0.0, 0.0, -5.0, -10.0, -15.0],
276+
),
277+
dict(
278+
testcase_name="zero_penalty",
279+
lengths=[110],
280+
expected_scores=[0.0],
281+
penalty=0,
282+
),
283+
)
284+
def test_reward_scores(self, lengths, expected_scores, penalty=10):
285+
completions = ["a" * length for length in lengths]
286+
overlong_buffer = {
287+
"overlong_buffer_length": 20,
288+
"overlong_buffer_penalty": penalty,
289+
"max_response_length": 100,
290+
}
291+
# expected_response_length = 100 - 20 = 80
292+
293+
scores = dapo_lib.reward_shaping(
294+
prompts=[""] * len(completions),
295+
completions=completions,
296+
mode=self.mock_cluster.Mode,
297+
overlong_buffer=overlong_buffer,
298+
)
299+
300+
self.assertSequenceAlmostEqual(expected_scores, scores, places=4)
114301

115302

116303
if __name__ == "__main__":

0 commit comments

Comments
 (0)