diff --git a/tests/rl/experimental/parity_test.py b/tests/rl/experimental/parity_test.py
new file mode 100644
index 00000000..8a80ca62
--- /dev/null
+++ b/tests/rl/experimental/parity_test.py
@@ -0,0 +1,339 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Parity test for Agentic GRPO vs GRPO."""
+
+import os
+import random
+from unittest import mock
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from flax import nnx
+import grain.python as grain
+import jax
+from jax import sharding
+from jax.interpreters import pxla
+import numpy as np
+import optax
+from tunix.generate import tokenizer_adapter
+from tunix.rl import rl_cluster as rl_cluster_lib
+from tunix.rl.experimental import agentic_grpo_learner
+from tunix.rl.grpo import grpo_learner as grpo_learner_lib
+from tunix.rl.rollout import base_rollout
+from tunix.tests import test_common as tc
+
+os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=2'
+Mesh = sharding.Mesh
+
+
+class MockChatParser:
+
+  def parse(self, messages, add_generation_prompt=False, is_first_msg=False):
+    # This mock parser flattens the chat messages to a simple string that
+    # matches what the standard GRPO learner receives (the raw prompt).
+    if not messages:
+      return ''
+    # We assume messages look like [{"role": "user", "content": "..."}]
+    # and return only the content of the last (user) message.
+    content = messages[-1]['content']
+    # Prepend a space to match the BOS token added by VanillaRollout/Sampler,
+    # but only for the first message (the System message in Agentic).
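+    # For illustration, given the branch below:
+    #   parse([{'role': 'user', 'content': 'hi'}], is_first_msg=True) -> ' hi'
+    #   parse([{'role': 'user', 'content': 'hi'}])                    -> 'hi'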
+    if is_first_msg:
+      return ' ' + content
+    return content
+
+  @property
+  def assistant_token(self):
+    return ''
+
+
+def _dummy_dataset(source, batch_size: int = 1):
+  return (
+      grain.MapDataset.source(source)
+      .batch(batch_size)
+      .map(lambda x: {'prompts': x, 'answer': x, 'question': x})
+  )
+
+
+def reward_fn(completions, **kwargs):  # pylint: disable=unused-argument
+  return [float(len(c)) for c in completions]
+
+
+class ParityTest(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    random.seed(42)
+    np.random.seed(42)
+
+  @parameterized.named_parameters(
+      dict(
+          testcase_name='multi_iter_without_gradient_accumulation',
+          name='multi_iter_without_gradient_accumulation',
+          num_iterations=2,
+          beta=0.04,
+          gradient_accumulation_steps=None,
+      ),
+      dict(
+          testcase_name='multi_iter_with_gradient_accumulation',
+          name='multi_iter_with_gradient_accumulation',
+          num_iterations=2,
+          beta=0.04,
+          gradient_accumulation_steps=3,
+      ),
+      dict(
+          testcase_name='multi_iter_without_kl',
+          name='multi_iter_without_kl',
+          num_iterations=2,
+          beta=0,
+          gradient_accumulation_steps=3,
+      ),
+      dict(
+          testcase_name='single_iter_with_gradient_accumulation',
+          name='single_iter_with_gradient_accumulation',
+          num_iterations=1,
+          beta=0.04,
+          gradient_accumulation_steps=3,
+      ),
+      dict(
+          testcase_name='single_iter_without_gradient_accumulation',
+          name='single_iter_without_gradient_accumulation',
+          num_iterations=1,
+          beta=0.04,
+          gradient_accumulation_steps=None,
+      ),
+      dict(
+          testcase_name='single_iter_without_kl',
+          name='single_iter_without_kl',
+          num_iterations=1,
+          beta=0,
+          gradient_accumulation_steps=None,
+      ),
+  )
+  def test_model_weights_parity(
+      self,
+      name,
+      num_iterations,
+      beta,
+      gradient_accumulation_steps,
+  ):
+    # TODO: b/446969561 - Re-enable these test cases.
+    if name in (
+        'multi_iter_with_gradient_accumulation',
+        'multi_iter_without_kl',
+        'single_iter_with_gradient_accumulation',
+    ):
+      self.skipTest(
+          'Skipping failing test cases with gradient accumulation > 1. See'
+          ' b/446969561 for details.'
+      )
+
+    vocab = tc.MockVocab()
+    tokenizer = tokenizer_adapter.TokenizerAdapter(vocab)
+
+    # Ensure the tokenizer has an apply_chat_template that behaves like
+    # MockChatParser, so Agentic's chat-formatted prompt matches Standard's
+    # raw prompt.
+    if not hasattr(tokenizer, 'apply_chat_template'):
+
+      def mock_apply_chat_template(messages, **kwargs):
+        return messages[-1]['content']
+
+      tokenizer.apply_chat_template = mock_apply_chat_template
+
+    # Patch tokenizer.encode to ensure consistent tokenization (e.g. no
+    # BOS/EOS) between Agentic (which calls encode with
+    # add_special_tokens=False) and Standard (which uses VanillaRollout
+    # defaults).
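+    # With this patch in place, both learners should see identical token ids
+    # for identical text: tokenizer.encode('hello world') simply returns
+    # vocab.EncodeAsIds('hello world'), regardless of add_special_tokens.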
+    def mock_encode(text, add_special_tokens=False):
+      del add_special_tokens  # Ignore the flag to enforce parity.
+      return vocab.EncodeAsIds(text)
+
+    tokenizer.encode = mock_encode
+
+    model1 = tc.ToyTransformer(
+        config=tc.ModelConfig(vocab_size=vocab.GetPieceSize()),
+        rngs=nnx.Rngs(0),
+    )
+    model2 = tc.ToyTransformer(
+        config=tc.ModelConfig(vocab_size=vocab.GetPieceSize()),
+        rngs=nnx.Rngs(0),
+    )
+
+    # Both models are seeded identically; verify their initial weights match.
+    variables1 = nnx.state(model1, nnx.Param)
+    variables2 = nnx.state(model2, nnx.Param)
+    jax.tree.map_with_path(tc.assert_close, variables1, variables2)
+
+    ref_model = tc.ToyTransformer(
+        config=tc.ModelConfig(vocab_size=vocab.GetPieceSize()),
+        rngs=nnx.Rngs(0),
+    )
+
+    # Common configs.
+    num_generations = 2
+    max_steps = 4
+    eval_every_n_steps = 12
+    max_prompt_length = 256
+    max_generation_steps = 10
+
+    mesh = pxla.thread_resources.env.physical_mesh
+    cluster_config = rl_cluster_lib.ClusterConfig(
+        role_to_mesh={
+            rl_cluster_lib.Role.ACTOR: mesh,
+            rl_cluster_lib.Role.REFERENCE: mesh,
+            rl_cluster_lib.Role.ROLLOUT: mesh,
+        },
+        rollout_engine='vanilla',
+        offload_to_cpu=False,
+        training_config=rl_cluster_lib.RLTrainingConfig(
+            actor_optimizer=optax.sgd(1e-3),
+            eval_every_n_steps=eval_every_n_steps,
+            max_steps=max_steps,
+            gradient_accumulation_steps=gradient_accumulation_steps,
+            # Ensure batch sizes match between the two learners.
+            mini_batch_size=1,
+            train_micro_batch_size=1,
+        ),
+        rollout_config=base_rollout.RolloutConfig(
+            max_tokens_to_generate=max_generation_steps,
+            max_prompt_length=max_prompt_length,
+            kv_cache_size=1024,
+            temperature=0.0,  # Deterministic (greedy) sampling.
+        ),
+    )
+
+    # 1. Setup Standard GRPO Learner
+    rl_cluster1 = rl_cluster_lib.RLCluster(
+        actor=model1,
+        reference=ref_model,
+        tokenizer=tokenizer,
+        cluster_config=cluster_config,
+    )
+
+    grpo_config1 = grpo_learner_lib.GRPOConfig(
+        num_generations=num_generations,
+        num_iterations=num_iterations,
+        beta=beta,
+        loss_algo='grpo',
+    )
+
+    grpo_learner = grpo_learner_lib.GRPOLearner(
+        rl_cluster=rl_cluster1,
+        reward_fns=reward_fn,
+        algo_config=grpo_config1,
+    )
+
+    # 2. Setup Agentic GRPO Learner
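+    # Mirrors the standard setup above (same reference model, tokenizer, and
+    # cluster config), so the two runs should differ only in the rollout path.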
+    rl_cluster2 = rl_cluster_lib.RLCluster(
+        actor=model2,
+        reference=ref_model,
+        tokenizer=tokenizer,
+        cluster_config=cluster_config,
+    )
+
+    grpo_config2 = agentic_grpo_learner.GRPOConfig(
+        num_generations=num_generations,
+        num_iterations=num_iterations,
+        beta=beta,
+        loss_algo='grpo',
+        system_prompt='',
+        max_concurrency=1,
+    )
+
+    agentic_learner = agentic_grpo_learner.GRPOLearner(
+        rl_cluster=rl_cluster2,
+        reward_fns=reward_fn,
+        algo_config=grpo_config2,
+        chat_parser=MockChatParser(),
+    )
+
+    # Data: repeat the prompts so there are enough batches for max_steps.
+    prompts = ['input string', 'hello world', 'My name', 'hello there']
+    train_ds = _dummy_dataset(prompts * 2, batch_size=1)
+
+    # Run training for both learners, spying on update_actor and sync_weights
+    # so we can compare how often they are called and with what data.
+    with mock.patch.object(
+        rl_cluster1, 'update_actor', wraps=rl_cluster1.update_actor
+    ) as mock_update1, mock.patch.object(
+        rl_cluster1, 'sync_weights', wraps=rl_cluster1.sync_weights
+    ) as mock_sync1, mock.patch.object(
+        rl_cluster2, 'update_actor', wraps=rl_cluster2.update_actor
+    ) as mock_update2, mock.patch.object(
+        rl_cluster2, 'sync_weights', wraps=rl_cluster2.sync_weights
+    ) as mock_sync2:
+      grpo_learner.train(train_ds, None)
+      agentic_learner.train(train_ds, None)
+
+    self.assertEqual(
+        mock_update1.call_count,
+        mock_update2.call_count,
+        msg=(
+            f'update_actor call count mismatch: {mock_update1.call_count} !='
+            f' {mock_update2.call_count}'
+        ),
+    )
+    self.assertEqual(
+        mock_sync1.call_count,
+        mock_sync2.call_count,
+        msg=(
+            f'sync_weights call count mismatch: {mock_sync1.call_count} !='
+            f' {mock_sync2.call_count}'
+        ),
+    )
+
+    # Verify update_actor arguments (data parity).
+    def get_train_examples(mock_update):
+      examples = []
+      for call in mock_update.call_args_list:
+        # args[0] is the batch (a list of TrainExample).
+        batch = call.args[0]
+        examples.extend(batch)
+      return examples
+
+    examples1 = get_train_examples(mock_update1)
+    examples2 = get_train_examples(mock_update2)
+
+    self.assertEqual(
+        len(examples1),
+        len(examples2),
+        msg='Number of training examples passed to update_actor does not match',
+    )
+
+    for i, (ex1, ex2) in enumerate(zip(examples1, examples2)):
+      np.testing.assert_array_equal(
+          ex1.prompt_ids,
+          ex2.prompt_ids,
+          err_msg=f'prompt_ids mismatch at index {i}',
+      )
+      np.testing.assert_array_equal(
+          ex1.completion_ids,
+          ex2.completion_ids,
+          err_msg=f'completion_ids mismatch at index {i}',
+      )
+      np.testing.assert_allclose(
+          ex1.advantages,
+          ex2.advantages,
+          atol=1e-5,
+          err_msg=f'advantages mismatch at index {i}',
+      )
+
+    # Verify weight parity: both models must end training with exactly equal
+    # parameters.
+    variables1 = nnx.state(model1, nnx.Param)
+    variables2 = nnx.state(model2, nnx.Param)
+    jax.tree.map_with_path(tc.assert_equal, variables1, variables2)
+
+
+if __name__ == '__main__':
+  absltest.main()