tests/rl/grpo/dapo_learner_test.py (212 changes: 207 additions, 5 deletions)
@@ -101,16 +101,218 @@ def test_diff_loss(self):
grpo_loss.item(),
msg=(
"DAPO and GRPO loss values should be different for the same input"
" due to different configurations and potentially different"
" logic."
" due to different loss aggregation logics."
),
)

self.assertIn("kl", dapo_aux)
self.assertIn("kl", grpo_aux)
self.assertEqual(dapo_aux["kl"], 0.0)  # DAPO does not include a KL term.
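# Background: the DAPO recipe aggregates the policy-gradient loss at the token
# level (token-mean) rather than GRPO's per-sequence averaging, which is the
# "different loss aggregation" referenced in the message above, and it drops
# the reference-policy KL penalty, which is why its "kl" metric is exactly 0.0
# here while GRPO's is not. (This describes the DAPO recipe in general; the
# exact aggregation inside dapo_lib is not shown in this diff.)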


class TestDAPOConfigPostInit(parameterized.TestCase):

def test_valid_default(self):
"""Tests that default values pass validation."""
try:
dapo_lib.DAPOConfig()
except ValueError as e:
self.fail(f"DAPOConfig raised ValueError on default initialization: {e}")

@parameterized.named_parameters(
dict(testcase_name="custom_epsilons", epsilon=0.1, epsilon_high=0.15),
dict(testcase_name="epsilons_equal", epsilon=0.1, epsilon_high=0.1),
dict(
testcase_name="buffer_disabled",
overlong_buffer={"enable": False},
),
dict(testcase_name="buffer_none", overlong_buffer=None),
dict(
testcase_name="valid_buffer",
overlong_buffer={
"enable": True,
"overlong_buffer_length": 2000,
"overlong_buffer_penalty": 0.5,
"max_response_length": 10000,
},
),
)
def test_valid_configurations(self, **kwargs):
"""Tests various valid custom configurations."""
try:
dapo_lib.DAPOConfig(**kwargs)
except ValueError as e:
self.fail(f"DAPOConfig raised ValueError for valid case {kwargs}: {e}")

@parameterized.named_parameters(
dict(
testcase_name="invalid_epsilon_high",
config_kwargs=dict(epsilon=0.2, epsilon_high=0.1),
expected_regex=(
"epsilon_high must be greater than or equal to epsilon."
),
),
dict(
testcase_name="buffer_missing_length",
config_kwargs=dict(
overlong_buffer={
"enable": True,
"overlong_buffer_penalty": 1.0,
"max_response_length": 20480,
}
),
expected_regex=(
"overlong_buffer is enabled but missing.*overlong_buffer_length.*"
),
),
dict(
testcase_name="buffer_missing_penalty",
config_kwargs=dict(
overlong_buffer={
"enable": True,
"overlong_buffer_length": 4096,
"max_response_length": 20480,
}
),
expected_regex=(
"overlong_buffer is enabled but missing"
".*overlong_buffer_penalty.*"
),
),
dict(
testcase_name="buffer_missing_max_length",
config_kwargs=dict(
overlong_buffer={
"enable": True,
"overlong_buffer_length": 4096,
"overlong_buffer_penalty": 1.0,
}
),
expected_regex=(
"overlong_buffer is enabled but missing.*max_response_length.*"
),
),
dict(
testcase_name="buffer_length_is_none",
config_kwargs=dict(
overlong_buffer={
"enable": True,
"overlong_buffer_length": None,
"overlong_buffer_penalty": 1.0,
"max_response_length": 20480,
}
),
expected_regex=(
"overlong_buffer is enabled but missing.*overlong_buffer_length.*"
),
),
dict(
testcase_name="negative_penalty",
config_kwargs=dict(
overlong_buffer={
"enable": True,
"overlong_buffer_length": 4096,
"overlong_buffer_penalty": -0.5,
"max_response_length": 20480,
}
),
expected_regex="overlong_buffer_penalty must be non-negative",
),
)
def test_invalid_configurations(self, config_kwargs, expected_regex):
"""Tests various invalid configurations that should raise ValueError."""
with self.assertRaisesRegex(ValueError, expected_regex):
dapo_lib.DAPOConfig(**config_kwargs)


class RewardShapingTest(parameterized.TestCase):

def setUp(self):
super().setUp()
self.mock_cluster = mock.MagicMock()

def test_raises_error_on_none_buffer(self):
with self.assertRaisesRegex(
ValueError, "reward_shaping is called but with empty overlong_buffer."
):
dapo_lib.reward_shaping(
prompts=["test prompt"],
completions=["test completion"],
mode=self.mock_cluster.Mode,
overlong_buffer=None,
)

def test_raises_error_on_non_positive_buffer_length(self):
with self.assertRaisesRegex(
ValueError, "overlong_buffer_length must be positive."
):
dapo_lib.reward_shaping(
prompts=["test prompt"],
completions=["test completion"],
mode=self.mock_cluster.Mode,
overlong_buffer={
"overlong_buffer_length": 0,
"overlong_buffer_penalty": 10,
"max_response_length": 100,
},
)
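# Note: these call-time checks in reward_shaping (non-None buffer, positive
# overlong_buffer_length) overlap with the DAPOConfig.__post_init__ validation
# tested above; presumably they guard direct calls that bypass the config.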

@parameterized.named_parameters(
dict(
testcase_name="under_length",
lengths=[70],
expected_scores=[0.0],
),
dict(
testcase_name="at_expected_length",
lengths=[80],
expected_scores=[0.0],
),
dict(
testcase_name="in_buffer_zone",
lengths=[90],
expected_scores=[-5.0],
),
dict(
testcase_name="at_max_length",
lengths=[100],
expected_scores=[-10.0],
),
dict(
testcase_name="over_max_length",
lengths=[110],
expected_scores=[-15.0],
),
dict(
testcase_name="mixed_lengths",
lengths=[70, 80, 90, 100, 110],
expected_scores=[0.0, 0.0, -5.0, -10.0, -15.0],
),
dict(
testcase_name="zero_penalty",
lengths=[110],
expected_scores=[0.0],
penalty=0,
),
)
def test_reward_scores(self, lengths, expected_scores, penalty=10):
completions = ["a" * length for length in lengths]
overlong_buffer = {
"overlong_buffer_length": 20,
"overlong_buffer_penalty": penalty,
"max_response_length": 100,
}
# expected_response_length = 100 - 20 = 80
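# Reading off the parameterized expectations above, the shaping behaves like
# (an inference from the test data, not a quote of the dapo_lib code):
#   score = min(0.0, (expected_response_length - completion_length)
#                    / overlong_buffer_length * overlong_buffer_penalty)
# e.g. length 90  -> min(0.0, (80 - 90) / 20 * 10)  = -5.0
#      length 110 -> min(0.0, (80 - 110) / 20 * 10) = -15.0, i.e. the penalty
#      keeps growing past max_response_length rather than clipping at -penalty.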

scores = dapo_lib.reward_shaping(
prompts=[""] * len(completions),
completions=completions,
mode=self.mock_cluster.Mode,
overlong_buffer=overlong_buffer,
)

self.assertSequenceAlmostEqual(expected_scores, scores, places=4)


if __name__ == "__main__":