diff --git a/tests/rl/algorithm_config_test.py b/tests/rl/algorithm_config_test.py
index a3a62a3ea..b56dce735 100644
--- a/tests/rl/algorithm_config_test.py
+++ b/tests/rl/algorithm_config_test.py
@@ -16,7 +16,6 @@ from absl.testing import parameterized
 from tunix.rl import algorithm_config
 
 
-
 class AlgorithmConfigTest(parameterized.TestCase):
 
   def test_defaults_are_valid(self):
@@ -54,7 +53,6 @@ def test_valid_combinations(self, algo: str, adv: str, loss: str):
     )
 
   @parameterized.named_parameters(
-      dict(testcase_name="invalid_algo_dapo", value="dapo"),
      dict(testcase_name="invalid_algo_else", value="something_else"),
   )
   def test_invalid_algo_variant(self, value: str):
@@ -118,5 +116,22 @@ def test_field_assignment(self):
     self.assertEqual(config.algo_variant, "invalid_after_init")
 
 
+  def test_config_logging(self):
+    """Tests that configuration is logged correctly upon initialization."""
+    # assertLogs catches logs at the specified level or higher.
+    with self.assertLogs(level="INFO") as log:
+      algorithm_config.AlgorithmConfig(
+          algo_variant="gspo-token", advantage_estimator="gae", policy_loss_fn="ppo"
+      )
+
+    # log.output is a list of strings like ['INFO:root:message...'].
+    full_log_output = "\n".join(log.output)
+
+    self.assertIn("Initializing AlgorithmConfig", full_log_output)
+    self.assertIn("algo_variant: gspo-token", full_log_output)
+    self.assertIn("advantage_estimator: gae", full_log_output)
+    self.assertIn("policy_loss_fn: ppo", full_log_output)
+
+
 if __name__ == "__main__":
   absltest.main()
diff --git a/tests/rl/function_registry_test.py b/tests/rl/function_registry_test.py
index 2525e79f6..95a86ac78 100644
--- a/tests/rl/function_registry_test.py
+++ b/tests/rl/function_registry_test.py
@@ -52,7 +52,7 @@ def test_custom_categories_instance(self):
   def test_empty_categories_instance(self):
     # Test-specific instance for empty categories
     registry = function_registry.FunctionRegistry(allowed_categories=[])
-    self.assertLen(registry.list_categories(), 2)
+    self.assertLen(registry.list_categories(), 3)
 
   @parameterized.named_parameters(
       dict(
diff --git a/tests/rl/grpo/dapo_learner_test.py b/tests/rl/grpo/dapo_learner_test.py
index d7567ec2b..4860b6aa7 100644
--- a/tests/rl/grpo/dapo_learner_test.py
+++ b/tests/rl/grpo/dapo_learner_test.py
@@ -101,16 +101,203 @@ def test_diff_loss(self):
         grpo_loss.item(),
         msg=(
             "DAPO and GRPO loss values should be different for the same input"
-            " due to different configurations and potentially different"
-            " logic."
+            " due to different loss aggregation logic."
         ),
     )
     self.assertIn("kl", dapo_aux)
     self.assertIn("kl", grpo_aux)
-    self.assertNotEqual(
-        dapo_aux["kl"], grpo_aux["kl"]
-    )  # Expected as beta differs
+    self.assertEqual(dapo_aux["kl"], 0.0)  # DAPO does not have KL term.
+ + +class TestDAPOConfigPostInit(parameterized.TestCase): + + def test_valid_default(self): + """Tests that default values pass validation.""" + try: + dapo_lib.DAPOConfig() + except ValueError as e: + self.fail(f"DAPOConfig raised ValueError on default initialization: {e}") + + @parameterized.named_parameters( + dict(testcase_name="custom_epsilons", epsilon=0.1, epsilon_high=0.15), + dict(testcase_name="epsilons_equal", epsilon=0.1, epsilon_high=0.1), + dict( + testcase_name="buffer_disabled", + overlong_buffer={"enable": False}, + ), + dict(testcase_name="buffer_none", overlong_buffer=None), + dict( + testcase_name="valid_buffer", + overlong_buffer={ + "enable": True, + "overlong_buffer_length": 2000, + "overlong_buffer_penalty": 0.5, + "max_response_length": 10000, + }, + ), + ) + def test_valid_configurations(self, **kwargs): + """Tests various valid custom configurations.""" + try: + dapo_lib.DAPOConfig(**kwargs) + except ValueError as e: + self.fail(f"DAPOConfig raised ValueError for valid case {kwargs}: {e}") + + @parameterized.named_parameters( + dict( + testcase_name="invalid_epsilon_high", + config_kwargs=dict(epsilon=0.2, epsilon_high=0.1), + expected_regex=( + "epsilon_high must be greater than or equal to epsilon." + ), + ), + dict( + testcase_name="buffer_missing_length", + config_kwargs=dict( + overlong_buffer={ + "enable": True, + "overlong_buffer_penalty": 1.0, + "max_response_length": 20480, + } + ), + expected_regex=( + "overlong_buffer is enabled but missing.*overlong_buffer_length.*" + ), + ), + dict( + testcase_name="buffer_missing_penalty", + config_kwargs=dict( + overlong_buffer={ + "enable": True, + "overlong_buffer_length": 4096, + "max_response_length": 20480, + } + ), + expected_regex=( + "overlong_buffer is enabled but missing" + ".*overlong_buffer_penalty.*" + ), + ), + dict( + testcase_name="buffer_missing_max_length", + config_kwargs=dict( + overlong_buffer={ + "enable": True, + "overlong_buffer_length": 4096, + "overlong_buffer_penalty": 1.0, + } + ), + expected_regex=( + "overlong_buffer is enabled but missing.*max_response_length.*" + ), + ), + dict( + testcase_name="buffer_length_is_none", + config_kwargs=dict( + overlong_buffer={ + "enable": True, + "overlong_buffer_length": None, + "overlong_buffer_penalty": 1.0, + "max_response_length": 20480, + } + ), + expected_regex=( + "overlong_buffer is enabled but missing.*overlong_buffer_length.*" + ), + ), + dict( + testcase_name="negative_penalty", + config_kwargs=dict( + overlong_buffer={ + "enable": True, + "overlong_buffer_length": 4096, + "overlong_buffer_penalty": -0.5, + "max_response_length": 20480, + } + ), + expected_regex="overlong_buffer_penalty must be non-negative", + ), + ) + def test_invalid_configurations(self, config_kwargs, expected_regex): + """Tests various invalid configurations that should raise ValueError.""" + with self.assertRaisesRegex(ValueError, expected_regex): + dapo_lib.DAPOConfig(**config_kwargs) + + +class RewardShapingTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.mock_cluster = mock.MagicMock() + + def test_raises_error_on_none_buffer(self): + with self.assertRaisesRegex( + ValueError, "reward_shaping is called but with empty overlong_buffer." 
+ ): + + dapo_lib.reward_shaping( + prompts=["test prompt"], + completions=["test completion"], + mode=self.mock_cluster.Mode, + overlong_buffer=None, + ) + + @parameterized.named_parameters( + dict( + testcase_name="under_length", + lengths=[70], + expected_scores=[0.0], + ), + dict( + testcase_name="at_expected_length", + lengths=[80], + expected_scores=[0.0], + ), + dict( + testcase_name="in_buffer_zone", + lengths=[90], + expected_scores=[-5.0], + ), + dict( + testcase_name="at_max_length", + lengths=[100], + expected_scores=[-10.0], + ), + dict( + testcase_name="over_max_length", + lengths=[110], + expected_scores=[-15.0], + ), + dict( + testcase_name="mixed_lengths", + lengths=[70, 80, 90, 100, 110], + expected_scores=[0.0, 0.0, -5.0, -10.0, -15.0], + ), + dict( + testcase_name="zero_penalty", + lengths=[110], + expected_scores=[0.0], + penalty=0, + ), + ) + def test_reward_scores(self, lengths, expected_scores, penalty=10): + completions = ["a" * length for length in lengths] + overlong_buffer = { + "overlong_buffer_length": 20, + "overlong_buffer_penalty": penalty, + "max_response_length": 100, + } + # expected_response_length = 100 - 20 = 80 + + scores = dapo_lib.reward_shaping( + prompts=[""] * len(completions), + completions=completions, + mode=self.mock_cluster.Mode, + overlong_buffer=overlong_buffer, + ) + + self.assertSequenceAlmostEqual(expected_scores, scores, places=4) if __name__ == "__main__": diff --git a/tests/rl/reward_test.py b/tests/rl/reward_test.py new file mode 100644 index 000000000..5d3674c2f --- /dev/null +++ b/tests/rl/reward_test.py @@ -0,0 +1,245 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import dataclasses +import inspect +from typing import Any, List +from absl import logging +from absl.testing import absltest +from absl.testing import parameterized +import mock +import numpy as np +import numpy.testing as npt +from tunix.rl import algorithm_config as algo_config_lib +from tunix.rl import reward + + +# --- Test Reward Functions --- +def len_reward( + prompts: List[str], completions: List[str], **kwargs: Any +) -> List[float]: + del prompts, kwargs # Unused + res = [float(len(c)) for c in completions] + return res + + +len_reward.__name__ = "len_reward" + + +def prompt_len_reward( + prompts: List[str], + completions: List[str], + custom_param: float = 1.0, + **kwargs: Any, +) -> List[float]: + del completions, kwargs # Unused + res = [custom_param * len(p) for p in prompts] + return res + + +prompt_len_reward.__name__ = "prompt_len_reward" + + +def nan_reward( + prompts: List[str], completions: List[str], **kwargs: Any +) -> List[float]: + del completions, kwargs # Unused + return [np.nan] * len(prompts) + + +nan_reward.__name__ = "nan_reward" + + +@dataclasses.dataclass(slots=True, kw_only=True) +class TestAlgoConfig(algo_config_lib.AlgorithmConfig): + """Test Algorithm Config.""" + + reward_manager: str = "sequence-level" + custom_param: float = 2.0 + + +# --- Test Class --- +class SequenceRewardManagerTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.test_algo_config = TestAlgoConfig() + self.prompts = ["p1", "p22"] + self.completions = ["c1_long", "c2"] + + def test_initialization(self): + manager = reward.SequenceRewardManager( + reward_fns=len_reward, + algo_config=self.test_algo_config, + ) + self.assertEqual(manager.reward_fns, [len_reward]) + self.assertEqual(manager.algo_config, self.test_algo_config) + + def test_single_reward_fn(self): + manager = reward.SequenceRewardManager( + reward_fns=[len_reward], + algo_config=self.test_algo_config, + ) + rewards_info = manager( + self.prompts, + self.completions, + ) + + expected_rewards = np.array([float(len("c1_long")), float(len("c2"))]) + np.testing.assert_array_equal(rewards_info["rewards"], expected_rewards) + logging.info("rewards_info[log_metrics]: %s", rewards_info["log_metrics"]) + self.assertLen(rewards_info["log_metrics"], 6) + + def test_multiple_reward_fns(self): + manager = reward.SequenceRewardManager( + reward_fns=[len_reward, prompt_len_reward], + algo_config=self.test_algo_config, + ) + rewards_info = manager( + self.prompts, + self.completions, + ) + + # custom_param is 2.0 from test_algo_config + r1 = np.array(len_reward(self.prompts, self.completions)) + r2 = np.array( + prompt_len_reward(self.prompts, self.completions, custom_param=2.0) + ) + expected_rewards = r1 + r2 + rewards_matrix = np.array([r1, r2]) + np.testing.assert_array_almost_equal( + rewards_info["rewards"], expected_rewards + ) + test_metrics = rewards_info["log_metrics"] + for metric_name, v in test_metrics.items(): + if metric_name.startswith("rewards/"): + self.assertLen(v[0], 2) + npt.assert_allclose( + test_metrics["rewards/sum"][0], + expected_rewards, + err_msg="rewards/sum mismatch", + ) + npt.assert_allclose( + test_metrics["rewards/len_reward"][0], + r1, + err_msg="rewards/len_reward mismatch", + ) + npt.assert_allclose( + test_metrics["rewards/prompt_len_reward"][0], + r2, + err_msg="rewards/prompt_len_reward mismatch", + ) + for col_idx in range(rewards_matrix.shape[0]): + npt.assert_allclose( + test_metrics["rewards/min"][0][col_idx], + np.min(rewards_matrix[:, col_idx]), + ) + 
+      npt.assert_allclose(
+          test_metrics["rewards/max"][0][col_idx],
+          np.max(rewards_matrix[:, col_idx]),
+      )
+
+  def test_algo_config_param_passing(self):
+    # Mock the reward function to spy on its call arguments
+    mock_fn = mock.Mock(wraps=prompt_len_reward)
+    mock_fn.__name__ = prompt_len_reward.__name__
+    # Restore the signature for introspection
+    mock_fn.__signature__ = inspect.signature(prompt_len_reward)
+
+    manager = reward.SequenceRewardManager(
+        reward_fns=[mock_fn],
+        algo_config=self.test_algo_config,
+    )
+    manager(
+        self.prompts,
+        self.completions,
+    )
+
+    mock_fn.assert_called_once()
+    _, kwargs = mock_fn.call_args
+    self.assertEqual(kwargs["custom_param"], 2.0)
+    self.assertNotIn(
+        "another_param", kwargs
+    )  # Not in prompt_len_reward signature
+
+  def test_nan_handling(self):
+    manager = reward.SequenceRewardManager(
+        reward_fns=[len_reward, nan_reward],
+        algo_config=self.test_algo_config,
+    )
+    rewards_info = manager(
+        self.prompts,
+        self.completions,
+    )
+    # np.nansum should treat nan as 0 for summation
+    expected_rewards = np.array([float(len(c)) for c in self.completions])
+    np.testing.assert_array_almost_equal(
+        rewards_info["rewards"], expected_rewards
+    )
+    # Check logged metrics for NaN
+    test_metrics = rewards_info["log_metrics"]
+    self.assertTrue(np.isnan(test_metrics["rewards/nan_reward"][0]).all())
+    np.testing.assert_allclose(
+        test_metrics["rewards/sum"][0],
+        expected_rewards,
+        err_msg="rewards/sum mismatch",
+    )
+
+  @parameterized.named_parameters(
+      dict(
+          testcase_name="reward_fn_returns_none",
+          reward_fns=[lambda prompts, completions, **kw: None],
+          expected_regex="Failed to obtain result.*Result is None",
+          error_type=RuntimeError,
+      ),
+      dict(
+          testcase_name="reward_fn_bad_length",
+          reward_fns=[
+              lambda prompts, completions, **kw: [1.0] * (len(prompts) + 1)
+          ],
+          expected_regex="Length mismatch",
+          error_type=RuntimeError,
+      ),
+  )
+  def test_errors(
+      self, expected_regex, error_type, kwargs=None, reward_fns=None
+  ):
+    if reward_fns is None:
+      reward_fns = [len_reward]
+    for i, fn in enumerate(reward_fns):
+      if not hasattr(fn, "__name__"):
+        fn.__name__ = f"test_fn_{i}"
+
+    manager = reward.SequenceRewardManager(
+        reward_fns=reward_fns,
+        algo_config=self.test_algo_config,
+    )
+    with self.assertRaisesRegex(error_type, expected_regex):
+      manager(
+          self.prompts,
+          self.completions,
+          **(kwargs or {}),
+      )
+
+  def test_no_reward_fns_raises_error(self):
+    with self.assertRaisesRegex(ValueError, "reward_fns cannot be empty"):
+      reward.SequenceRewardManager(
+          reward_fns=[],
+          algo_config=self.test_algo_config,
+      )
+
+
+if __name__ == "__main__":
+  absltest.main()
diff --git a/tunix/rl/algorithm_config.py b/tunix/rl/algorithm_config.py
index 452e5c2f5..0a4b178cf 100644
--- a/tunix/rl/algorithm_config.py
+++ b/tunix/rl/algorithm_config.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import dataclasses
-
+from absl import logging
 
 @dataclasses.dataclass(slots=True, kw_only=True)
 class AlgorithmConfig:
@@ -28,12 +28,14 @@ class AlgorithmConfig:
   algo_variant: str = "grpo"
   advantage_estimator: str = "grpo"
   policy_loss_fn: str = "grpo"
+  reward_manager: str = "sequence-level"
 
   def __post_init__(self):
     valid_algo_variants = [
         "grpo",
-        "gspo",
+        "gspo-token",
         "ppo",
+        "dapo",
     ]
     valid_advantage_estimators = ["grpo", "gae"]
     valid_policy_loss_fns = ["grpo", "ppo"]
@@ -52,3 +54,13 @@ def __post_init__(self):
          f"policy_loss_fn must be one of {valid_policy_loss_fns}."
f" Received: {self.policy_loss_fn}" ) + + # Automatically prints configuration upon initialization. + self.print_config() + + def print_config(self): + """Prints all configuration fields, working dynamically for child classes.""" + logging.info(f"Initializing {self.__class__.__name__}:") + for field in dataclasses.fields(self): + value = getattr(self, field.name) + logging.info(f" {field.name}: {value}") diff --git a/tunix/rl/experimental/agentic_grpo_learner.py b/tunix/rl/experimental/agentic_grpo_learner.py index 9acbc6544..1d739d09a 100644 --- a/tunix/rl/experimental/agentic_grpo_learner.py +++ b/tunix/rl/experimental/agentic_grpo_learner.py @@ -166,13 +166,6 @@ def __init__( policy_loss_fn = function_registry.get_policy_loss_fn( self.algo_config.policy_loss_fn ) - logging.info( - "algo_config.policy_loss_fn: %s", self.algo_config.policy_loss_fn - ) - logging.info("type(policy_loss_fn): %s", type(policy_loss_fn)) - - # Log the string representation of the callable - logging.info("repr(policy_loss_fn): %r", policy_loss_fn) loss_fn = lambda model, train_example, algo_config: policy_loss_fn( model, train_example, diff --git a/tunix/rl/experimental/agentic_rl_learner.py b/tunix/rl/experimental/agentic_rl_learner.py index 8f6ad8310..688b7898e 100644 --- a/tunix/rl/experimental/agentic_rl_learner.py +++ b/tunix/rl/experimental/agentic_rl_learner.py @@ -33,6 +33,8 @@ import numpy as np from tunix.rl import algorithm_config as algo_config_lib from tunix.rl import common +from tunix.rl import function_registry +from tunix.rl import reward from tunix.rl import rl_cluster as rl_cluster_lib from tunix.rl import utils as rl_utils from tunix.rl.agentic import utils as agentic_utils @@ -123,6 +125,14 @@ def __init__( """ self.rl_cluster = rl_cluster self.algo_config = algo_config + + reward_manager_fn = function_registry.get_reward_manager( + algo_config.reward_manager + ) + self.reward_manager = reward_manager_fn( + reward_fns=reward_fns, + algo_config=algo_config, + ) self.reward_fns = ( [reward_fns] if not isinstance(reward_fns, Sequence) else reward_fns ) @@ -187,7 +197,7 @@ def _compute_rewards( mode: rl_cluster_lib.Mode, expected_step: int | None = None, **kwargs, - ) -> jax.Array: + ) -> np.ndarray: """Computes the rewards for completions using the provided reward functions. Args: @@ -211,62 +221,26 @@ def _compute_rewards( raise ValueError(f"kwargs already contains mode as a key: {kwargs}") kwargs["mode"] = str(mode) - num_prompts = len(prompts) - num_reward_fns = len(self.reward_fns) - rewards = np.zeros((num_prompts, num_reward_fns)) - - # Compute all rewards for each prompt-completion pair. - for i, reward_fn in enumerate(self.reward_fns): - r = reward_fn(prompts=prompts, completions=completions, **kwargs) - - if r is None: - raise RuntimeError( - f"Failed to obtain result from {reward_fn.__name__}. Result is" - " None." - ) - if isinstance(r, list) and len(r) != len(prompts): - raise RuntimeError( - f"Length mismatch after {reward_fn.__name__}: " - f"len(r)={len(r)}, len(prompts)={num_prompts}. " - f"Content of r: {r}" - ) + rewards_info = self.reward_manager( + prompts=prompts, + completions=completions, + **kwargs, + ) - rewards[:, i] = np.array(r) - - # Sum rewards across all reward functions for each prompt. - sum_rewards = np.nansum(rewards, axis=1) - - # Log all metrics in a single loop - for j, (prompt, completion) in enumerate(zip(prompts, completions)): - metrics_to_log = {} - - # Log prompts and completions. 
- metrics_to_log["prompts"] = (prompt, None) - metrics_to_log["completions"] = (completion, None) - - # Log the summed rewards for this trajectory. - trajectory_sum = sum_rewards[j] - metrics_to_log["rewards/sum"] = (trajectory_sum, np.mean) - metrics_to_log["rewards/min"] = (np.min(rewards[j]), np.min) - metrics_to_log["rewards/max"] = (np.max(rewards[j]), np.max) - - # Log individual rewards for this trajectory - for i, reward_fn in enumerate(self.reward_fns): - metric_name = f"rewards/{reward_fn.__name__}" - metrics_to_log[metric_name] = (rewards[j, i], np.mean) - - # Log all metrics for this trajectory in one call - if expected_step is not None: - # Pass the expected_step explicitly because it is calculated based on - # the batch index (predicted step) to align metrics with the correct - # training step in the asynchronous execution. - self.rl_cluster.buffer_metrics_async( - metrics_to_log, mode=mode, step=expected_step - ) - else: - self.rl_cluster.buffer_metrics_async(metrics_to_log, mode=mode) + # Log all metrics for this trajectory in one call + if expected_step is not None: + # Pass the expected_step explicitly because it is calculated based on + # the batch index (predicted step) to align metrics with the correct + # training step in the asynchronous execution. + self.rl_cluster.buffer_metrics_async( + rewards_info["log_metrics"], mode=mode, step=expected_step + ) + else: + self.rl_cluster.buffer_metrics_async( + rewards_info["log_metrics"], mode=mode + ) - return jnp.array(sum_rewards) + return rewards_info["rewards"] def _create_micro_batch_iterator( self, diff --git a/tunix/rl/function_registry.py b/tunix/rl/function_registry.py index a90e2ace3..57204dad6 100644 --- a/tunix/rl/function_registry.py +++ b/tunix/rl/function_registry.py @@ -18,6 +18,7 @@ _POLICY_LOSS_FN_CATEGORY = "policy_loss_fn" _ADVANTAGE_ESTIMATOR_CATEGORY = "advantage_estimator" +_REWARD_MANAGER_CATEGORY = "reward_manager" class FunctionRegistry: @@ -26,6 +27,7 @@ class FunctionRegistry: DEFAULT_ALLOWED_CATEGORIES: FrozenSet[str] = frozenset({ _POLICY_LOSS_FN_CATEGORY, _ADVANTAGE_ESTIMATOR_CATEGORY, + _REWARD_MANAGER_CATEGORY, }) def __init__(self, allowed_categories: Optional[Iterable[str]] = None): @@ -135,3 +137,15 @@ def register_advantage_estimator( ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """Returns a decorator to register an advantage estimator function by name.""" return default_registry.register(_ADVANTAGE_ESTIMATOR_CATEGORY, name) + + +def get_reward_manager(name: str) -> Callable[..., Any]: + """Returns the reward manager function by name.""" + return default_registry.get(_REWARD_MANAGER_CATEGORY, name) + + +def register_reward_manager( + name: str, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """Returns a decorator to register a reward manager function by name.""" + return default_registry.register(_REWARD_MANAGER_CATEGORY, name) diff --git a/tunix/rl/grpo/dapo_learner.py b/tunix/rl/grpo/dapo_learner.py index 972316b14..8b318d1ee 100644 --- a/tunix/rl/grpo/dapo_learner.py +++ b/tunix/rl/grpo/dapo_learner.py @@ -29,7 +29,7 @@ class DAPOConfig(grpo_learner_lib.GRPOConfig): """Configuration for DAPO. Attributes: - algo_variant: The core algorithm variant to use. + algo_variant: The algorithm variant to use. advantage_estimator: The advantage estimator to use. policy_loss_fn: The policy loss function to use. loss_agg_mode: The aggregation mode for the loss function. 
@@ -56,6 +56,7 @@ class DAPOConfig(grpo_learner_lib.GRPOConfig): advantage_estimator: str = "grpo" policy_loss_fn: str = "grpo" loss_agg_mode: str = "token-mean" + reward_manager: str = "sequence-level" num_generations: int = 2 num_iterations: int = 1 beta: None = None # No KL term. @@ -64,10 +65,65 @@ class DAPOConfig(grpo_learner_lib.GRPOConfig): dynamic_sampling: bool = True # TODO(sizhi): Add dynamic sampling. overlong_buffer: Optional[Dict[str, Any]] = dataclasses.field( default_factory=lambda: { - "buffer_len": 1024, - "float": 1.0, + "enable": True, + "overlong_buffer_length": 4096, # Threshold before penalties apply. + "overlong_buffer_penalty": 1.0, + "max_response_length": 20480, # Hard maximum generation length. } - ) # TODO(sizhi): Add overlong buffer. + ) + + def __post_init__(self): + if self.epsilon_high < self.epsilon: + raise ValueError("epsilon_high must be greater than or equal to epsilon.") + + if self.overlong_buffer is not None and self.overlong_buffer.get("enable"): + buffer = self.overlong_buffer + required = [ + "overlong_buffer_length", + "overlong_buffer_penalty", + "max_response_length", + ] + + missing = [k for k in required if buffer.get(k) is None] + if missing: + raise ValueError(f"overlong_buffer is enabled but missing: {missing}") + + if buffer["overlong_buffer_penalty"] < 0: + raise ValueError("overlong_buffer_penalty must be non-negative.") + + if buffer["overlong_buffer_length"] <= 0: + raise ValueError("overlong_buffer_length must be positive.") + + if buffer["max_response_length"] <= 0: + raise ValueError("max_response_length must be positive.") + + +def reward_shaping( + prompts: List[str], + completions: List[str], + mode: rl_cluster_lib.Mode, + overlong_buffer: Dict[str, Any] | None = None, + **kwargs, +) -> List[float]: + """Reward shaping function for DAPO.""" + del prompts, mode, kwargs + if overlong_buffer is None: + raise ValueError("reward_shaping is called but with empty overlong_buffer.") + + overlong_buffer_length = overlong_buffer["overlong_buffer_length"] + overlong_buffer_penalty = overlong_buffer["overlong_buffer_penalty"] + max_response_length = overlong_buffer["max_response_length"] + + expected_response_length = max_response_length - overlong_buffer_length + scores = [] + for completion in completions: + output_length = len(completion) + exceed_length = output_length - expected_response_length + overlong_reward = min( + -exceed_length / overlong_buffer_length * overlong_buffer_penalty, 0 + ) + scores.append(overlong_reward) + return scores class DAPOLearner(grpo_learner_lib.GrpoLearner[DAPOConfig]): @@ -82,6 +138,11 @@ def __init__( data_shuffle_seed: int | None = None, ): """Initializes the `DAPOLearner`.""" + reward_fns = ( + [reward_fns] if not isinstance(reward_fns, Sequence) else reward_fns + ) + if algo_config.overlong_buffer and algo_config.overlong_buffer["enable"]: + reward_fns.append(reward_shaping) super().__init__( rl_cluster=rl_cluster, algo_config=algo_config, diff --git a/tunix/rl/grpo/grpo_learner.py b/tunix/rl/grpo/grpo_learner.py index cb15275a7..2333d6df0 100644 --- a/tunix/rl/grpo/grpo_learner.py +++ b/tunix/rl/grpo/grpo_learner.py @@ -44,10 +44,12 @@ class GRPOConfig(algo_config_lib.AlgorithmConfig): """Configuration for GRPO algorithms. Attributes: - algo_variant: The core algorithm variant to use. - advantage_estimator: The advantage estimator to use. - policy_loss_fn: The policy loss function to use. - loss_agg_mode: The aggregation mode for the loss function. + algo_variant: The algorithm variant to use. 
Default: `grpo`. + advantage_estimator: The advantage estimator to use. Default: `grpo`. + policy_loss_fn: The policy loss function to use. Default: `grpo`. + loss_agg_mode: The aggregation mode for the loss function. Default: + `sequence-mean-token-mean`. + reward_manager: The reward manager to use. Default: `sequence-level`. loss_algo: The loss algorithm to use. To be deprecated. num_generations: The number of times the policy generates multiple responses for a given prompt within a single training step. This corresponds to 'G' @@ -64,15 +66,16 @@ class GRPOConfig(algo_config_lib.AlgorithmConfig): normalized instead of per-response normalized as mentioned in the paper. For GSPO, we use gspo-token loss which is more flexible. - References: - - GRPO: https://arxiv.org/abs/2402.03300 - - GSPO: https://arxiv.org/abs/2507.18071 + References: + - GRPO: https://arxiv.org/abs/2402.03300 + - GSPO: https://arxiv.org/abs/2507.18071 """ algo_variant: str = "grpo" advantage_estimator: str = "grpo" policy_loss_fn: str = "grpo" loss_agg_mode: str = "sequence-mean-token-mean" + reward_manager: str = "sequence-level" loss_algo: ( str ) = ( # grpo or gspo-token # TODO(sizhi): Remove this option once gspo is @@ -112,8 +115,6 @@ class GRPOLearner(rl_learner.RLLearner[TGrpoConfig]): using a reward model, and then calculating a relative advantage based on the group's performance to update the policy. - References: - - https://arxiv.org/abs/2402.03300 """ def __init__( diff --git a/tunix/rl/ppo/ppo_learner.py b/tunix/rl/ppo/ppo_learner.py index a355c7ee9..168f8d99b 100644 --- a/tunix/rl/ppo/ppo_learner.py +++ b/tunix/rl/ppo/ppo_learner.py @@ -48,6 +48,10 @@ class PPOConfig(algo_config_lib.AlgorithmConfig): """Configuration for PPO learner. Attributes: + algo_variant: The algorithm variant to use. Default: `ppo`. + advantage_estimator: The advantage estimator to use. Default: `gae`. + policy_loss_fn: The policy loss function to use. Default: `ppo`. + reward_manager: The reward manager to use. Default: `sequence-level`. num_iterations: The number of optimization epochs per batch of rollouts. This corresponds to the number of times the policy updates its weights for a given batch of rollouts. @@ -75,6 +79,7 @@ class PPOConfig(algo_config_lib.AlgorithmConfig): algo_variant: str = "ppo" advantage_estimator: str = "gae" policy_loss_fn: str = "ppo" + reward_manager: str = "sequence-level" num_iterations: int = 1 # PPO loss and advantage computation configs. diff --git a/tunix/rl/reward.py b/tunix/rl/reward.py new file mode 100644 index 000000000..e15472b8a --- /dev/null +++ b/tunix/rl/reward.py @@ -0,0 +1,191 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Reward output for RL.""" + +import abc +from dataclasses import asdict +import inspect +from typing import Any, Callable, Dict, List, Sequence +import numpy as np +from tunix.rl import algorithm_config as algo_config_lib +from tunix.rl import function_registry + +RewardFn = Callable[..., Any] + + +class AbstractRewardManager(abc.ABC): + """Abstract base class for managing and orchestrating multiple reward function outputs.""" + + def __init__( + self, + reward_fns: RewardFn | List[RewardFn], + algo_config: algo_config_lib.AlgorithmConfig, + ): + """Initializes the manager with a list of callable reward function objects. + + Args: + reward_fns: A list of reward functions or models. + algo_config: The algorithm config to use for reward function + configuration. + """ + self.reward_fns = ( + [reward_fns] if not isinstance(reward_fns, Sequence) else reward_fns + ) + + if not self.reward_fns: + raise ValueError( + "reward_fns cannot be empty. You must provide at least one reward" + " function." + ) + self.algo_config = algo_config + + @abc.abstractmethod + def __call__( + self, + prompts: List[str], + completions: List[str], + **kwargs, + ) -> Dict[str, Any]: + """Computes the rewards for completions using the provided reward functions. + + Args: + prompts: A list of input prompts. + completions: A list of generated text completions. + mode: The mode to use for logging metrics. + **kwargs: Additional keyword arguments passed to the reward functions. + + Returns: + A dictionary of rewards information, including the final rewards for + advantage computation and intermediate rewards for logging. + """ + pass + + +@function_registry.register_reward_manager("sequence-level") +class SequenceRewardManager(AbstractRewardManager): + """Reward manager for sequence-level rewards only.""" + + def __init__( + self, + reward_fns: RewardFn | List[RewardFn], + algo_config: algo_config_lib.AlgorithmConfig, + **kwargs, + ): + """Initializes the manager with a list of callable reward function objects.""" + super().__init__(reward_fns, algo_config) + + def __call__( + self, + prompts: List[str], + completions: List[str], + **kwargs, + ) -> Dict[str, Any]: + """Computes the rewards for completions using the provided reward function, and return the sequence-level rewards information for advantage computationand logging.""" + return self._compute_rewards(prompts, completions, **kwargs) + + def _compute_rewards( + self, + prompts: List[str], + completions: List[str], + **kwargs, + ) -> Dict[str, Any]: + """Computes the rewards for completions using the provided reward functions.""" + + algo_config_params = asdict(self.algo_config) + base_kwargs = kwargs.copy() + + num_prompts = len(prompts) + num_reward_fns = len(self.reward_fns) + rewards = np.zeros((num_prompts, num_reward_fns)) + + # Compute all rewards for each prompt-completion pair. + for i, reward_fn in enumerate(self.reward_fns): + # Update the kwargs with the algo_config parameters. + signature = inspect.signature(reward_fn) + reward_fn_config_params = {} + # Iterate over the function's expected parameters + for name, _ in signature.parameters.items(): + # Skip standard parameters that are always passed (self, prompts, completions, kwargs) + if name in ["self", "prompts", "completions", "kwargs"]: + continue + + # Check if the parameter name matches a key in the algo_config dict. If + # so, set the value to the algo_config parameter value, otherwise respect the value in the base_kwargs. 
+ if name in algo_config_params and name not in base_kwargs: + reward_fn_config_params[name] = algo_config_params[name] + + call_kwargs = base_kwargs.copy() + call_kwargs.update(reward_fn_config_params) + + r = reward_fn(prompts=prompts, completions=completions, **call_kwargs) + + if r is None: + raise RuntimeError( + f"Failed to obtain result from {reward_fn.__name__}. Result is" + " None." + ) + if isinstance(r, list) and len(r) != len(prompts): + raise RuntimeError( + f"Length mismatch after {reward_fn.__name__}: " + f"len(r)={len(r)}, len(prompts)={num_prompts}. " + f"Content of r: {r}" + ) + + rewards[:, i] = np.array(r) + + # Sum rewards across all reward functions for each prompt. + sum_rewards = np.nansum(rewards, axis=1) + + # Prepare metrics for logging. + log_metrics = self._prepare_log_metrics( + prompts, + completions, + rewards, + sum_rewards, + ) + rewards_info = { + "rewards": sum_rewards, + "log_metrics": log_metrics, + } + return rewards_info + + def _prepare_log_metrics( + self, + prompts: List[str], + completions: List[str], + rewards: np.ndarray, # (num_prompts, num_reward_fns) + sum_rewards: np.ndarray, # (num_prompts,) + ) -> Dict[str, Any]: + """Logs individual and summed rewards, along with prompts/completions, for each trajectory.""" + # Assuming self.reward_fns and self.rl_cluster are accessible instance attributes + metrics_to_log = {} + + # Log prompts and completions. + metrics_to_log["prompts"] = (prompts, None) + metrics_to_log["completions"] = (completions, None) + + # Log the sum rewards for each prompt-completion pair. + metrics_to_log["rewards/sum"] = (sum_rewards, np.mean) + + # Log the min and max rewards for the prompt-completion pair. + metrics_to_log["rewards/min"] = (np.min(rewards, axis=1), np.min) + metrics_to_log["rewards/max"] = (np.max(rewards, axis=1), np.max) + + # Log individual rewards for this trajectory + for i, reward_fn in enumerate(self.reward_fns): + metric_name = f"rewards/{reward_fn.__name__}" + metrics_to_log[metric_name] = (rewards[:, i], np.mean) + + return metrics_to_log diff --git a/tunix/rl/rl_learner.py b/tunix/rl/rl_learner.py index df4e09911..f9977595c 100644 --- a/tunix/rl/rl_learner.py +++ b/tunix/rl/rl_learner.py @@ -29,6 +29,8 @@ import numpy as np from tunix.rl import algorithm_config as algo_config_lib from tunix.rl import common +from tunix.rl import function_registry +from tunix.rl import reward from tunix.rl import rl_cluster as rl_cluster_lib from tunix.rl import utils as rl_utils from tunix.rl.queue import data_queue as queue_lib @@ -79,9 +81,15 @@ def __init__( """ self.rl_cluster = rl_cluster self.algo_config = algo_config - self.reward_fns = ( - [reward_fns] if not isinstance(reward_fns, Sequence) else reward_fns + + reward_manager_fn = function_registry.get_reward_manager( + algo_config.reward_manager + ) + self.reward_manager = reward_manager_fn( + reward_fns=reward_fns, + algo_config=algo_config, ) + self.metric_fns = metric_fns or [] self.rl_cluster.actor_trainer.is_managed_externally = True if hasattr(self.rl_cluster, "critic_trainer"): @@ -183,59 +191,20 @@ def _compute_rewards( raise ValueError(f"kwargs already contains mode as a key: {kwargs}") kwargs["mode"] = str(mode) - num_prompts = len(prompts) - num_reward_fns = len(self.reward_fns) - rewards = np.zeros((num_prompts, num_reward_fns)) - - # Compute all rewards for each prompt-completion pair. 
- for i, reward_fn in enumerate(self.reward_fns): - r = reward_fn(prompts=prompts, completions=completions, **kwargs) - - if r is None: - raise RuntimeError( - f"Failed to obtain result from {reward_fn.__name__}. Result is" - " None." - ) - if isinstance(r, list) and len(r) != len(prompts): - raise RuntimeError( - f"Length mismatch after {reward_fn.__name__}: " - f"len(r)={len(r)}, len(prompts)={num_prompts}. " - f"Content of r: {r}" - ) - - rewards[:, i] = np.array(r) - - # Sum rewards across all reward functions for each prompt. - sum_rewards = np.nansum(rewards, axis=1) - - # Log all metrics in a single loop - for j, (prompt, completion) in enumerate(zip(prompts, completions)): - metrics_to_log = {} - - # Log prompts and completions. - metrics_to_log["prompts"] = (prompt, None) - metrics_to_log["completions"] = (completion, None) - - # Log the summed rewards for this trajectory. - trajectory_sum = sum_rewards[j] - metrics_to_log["rewards/sum"] = (trajectory_sum, np.mean) - metrics_to_log["rewards/min"] = (np.min(rewards[j]), np.min) - metrics_to_log["rewards/max"] = (np.max(rewards[j]), np.max) - - # Log individual rewards for this trajectory - for i, reward_fn in enumerate(self.reward_fns): - metric_name = f"rewards/{reward_fn.__name__}" - metrics_to_log[metric_name] = (rewards[j, i], np.mean) + rewards_info = self.reward_manager( + prompts=prompts, + completions=completions, + **kwargs, + ) - # Log all metrics for this trajectory in one call - if step is not None: - self.rl_cluster.buffer_metrics_async( - metrics_to_log, mode=mode, step=step - ) - else: - self.rl_cluster.buffer_metrics(metrics_to_log, mode=mode) + if step is not None: + self.rl_cluster.buffer_metrics_async( + rewards_info["log_metrics"], mode=mode, step=step + ) + else: + self.rl_cluster.buffer_metrics(rewards_info["log_metrics"], mode=mode) - return sum_rewards + return rewards_info["rewards"] def _process_accumulated_batches( self, diff --git a/tunix/tests/test_common.py b/tunix/tests/test_common.py index 88b4f9061..2498d67d0 100644 --- a/tunix/tests/test_common.py +++ b/tunix/tests/test_common.py @@ -41,6 +41,7 @@ def _convert_to_nparray(arr): return np.asarray(arr) return arr + def assert_equal(path, x, y): np.testing.assert_array_equal( _convert_to_nparray(x),