@@ -101,16 +101,203 @@ def test_diff_loss(self):
101101 grpo_loss .item (),
102102 msg = (
103103 "DAPO and GRPO loss values should be different for the same input"
104- " due to different configurations and potentially different"
105- " logic."
104+ " due to different loss aggregation logics."
106105 ),
107106 )
108107
109108 self .assertIn ("kl" , dapo_aux )
110109 self .assertIn ("kl" , grpo_aux )
111- self .assertNotEqual (
112- dapo_aux ["kl" ], grpo_aux ["kl" ]
113- ) # Expected as beta differs
110+ self .assertEqual (dapo_aux ["kl" ], 0.0 ) # DAPO does not have KL term.
111+
112+
113+ class TestDAPOConfigPostInit (parameterized .TestCase ):
114+
115+ def test_valid_default (self ):
116+ """Tests that default values pass validation."""
117+ try :
118+ dapo_lib .DAPOConfig ()
119+ except ValueError as e :
120+ self .fail (f"DAPOConfig raised ValueError on default initialization: { e } " )
121+
122+ @parameterized .named_parameters (
123+ dict (testcase_name = "custom_epsilons" , epsilon = 0.1 , epsilon_high = 0.15 ),
124+ dict (testcase_name = "epsilons_equal" , epsilon = 0.1 , epsilon_high = 0.1 ),
125+ dict (
126+ testcase_name = "buffer_disabled" ,
127+ overlong_buffer = {"enable" : False },
128+ ),
129+ dict (testcase_name = "buffer_none" , overlong_buffer = None ),
130+ dict (
131+ testcase_name = "valid_buffer" ,
132+ overlong_buffer = {
133+ "enable" : True ,
134+ "overlong_buffer_length" : 2000 ,
135+ "overlong_buffer_penalty" : 0.5 ,
136+ "max_response_length" : 10000 ,
137+ },
138+ ),
139+ )
140+ def test_valid_configurations (self , ** kwargs ):
141+ """Tests various valid custom configurations."""
142+ try :
143+ dapo_lib .DAPOConfig (** kwargs )
144+ except ValueError as e :
145+ self .fail (f"DAPOConfig raised ValueError for valid case { kwargs } : { e } " )
146+
147+ @parameterized .named_parameters (
148+ dict (
149+ testcase_name = "invalid_epsilon_high" ,
150+ config_kwargs = dict (epsilon = 0.2 , epsilon_high = 0.1 ),
151+ expected_regex = (
152+ "epsilon_high must be greater than or equal to epsilon."
153+ ),
154+ ),
155+ dict (
156+ testcase_name = "buffer_missing_length" ,
157+ config_kwargs = dict (
158+ overlong_buffer = {
159+ "enable" : True ,
160+ "overlong_buffer_penalty" : 1.0 ,
161+ "max_response_length" : 20480 ,
162+ }
163+ ),
164+ expected_regex = (
165+ "overlong_buffer is enabled but missing.*overlong_buffer_length.*"
166+ ),
167+ ),
168+ dict (
169+ testcase_name = "buffer_missing_penalty" ,
170+ config_kwargs = dict (
171+ overlong_buffer = {
172+ "enable" : True ,
173+ "overlong_buffer_length" : 4096 ,
174+ "max_response_length" : 20480 ,
175+ }
176+ ),
177+ expected_regex = (
178+ "overlong_buffer is enabled but missing"
179+ ".*overlong_buffer_penalty.*"
180+ ),
181+ ),
182+ dict (
183+ testcase_name = "buffer_missing_max_length" ,
184+ config_kwargs = dict (
185+ overlong_buffer = {
186+ "enable" : True ,
187+ "overlong_buffer_length" : 4096 ,
188+ "overlong_buffer_penalty" : 1.0 ,
189+ }
190+ ),
191+ expected_regex = (
192+ "overlong_buffer is enabled but missing.*max_response_length.*"
193+ ),
194+ ),
195+ dict (
196+ testcase_name = "buffer_length_is_none" ,
197+ config_kwargs = dict (
198+ overlong_buffer = {
199+ "enable" : True ,
200+ "overlong_buffer_length" : None ,
201+ "overlong_buffer_penalty" : 1.0 ,
202+ "max_response_length" : 20480 ,
203+ }
204+ ),
205+ expected_regex = (
206+ "overlong_buffer is enabled but missing.*overlong_buffer_length.*"
207+ ),
208+ ),
209+ dict (
210+ testcase_name = "negative_penalty" ,
211+ config_kwargs = dict (
212+ overlong_buffer = {
213+ "enable" : True ,
214+ "overlong_buffer_length" : 4096 ,
215+ "overlong_buffer_penalty" : - 0.5 ,
216+ "max_response_length" : 20480 ,
217+ }
218+ ),
219+ expected_regex = "overlong_buffer_penalty must be non-negative" ,
220+ ),
221+ )
222+ def test_invalid_configurations (self , config_kwargs , expected_regex ):
223+ """Tests various invalid configurations that should raise ValueError."""
224+ with self .assertRaisesRegex (ValueError , expected_regex ):
225+ dapo_lib .DAPOConfig (** config_kwargs )
226+
227+
228+ class RewardShapingTest (parameterized .TestCase ):
229+
230+ def setUp (self ):
231+ super ().setUp ()
232+ self .mock_cluster = mock .MagicMock ()
233+
234+ def test_raises_error_on_none_buffer (self ):
235+ with self .assertRaisesRegex (
236+ ValueError , "reward_shaping is called but with empty overlong_buffer."
237+ ):
238+
239+ dapo_lib .reward_shaping (
240+ prompts = ["test prompt" ],
241+ completions = ["test completion" ],
242+ mode = self .mock_cluster .Mode ,
243+ overlong_buffer = None ,
244+ )
245+
246+ @parameterized .named_parameters (
247+ dict (
248+ testcase_name = "under_length" ,
249+ lengths = [70 ],
250+ expected_scores = [0.0 ],
251+ ),
252+ dict (
253+ testcase_name = "at_expected_length" ,
254+ lengths = [80 ],
255+ expected_scores = [0.0 ],
256+ ),
257+ dict (
258+ testcase_name = "in_buffer_zone" ,
259+ lengths = [90 ],
260+ expected_scores = [- 5.0 ],
261+ ),
262+ dict (
263+ testcase_name = "at_max_length" ,
264+ lengths = [100 ],
265+ expected_scores = [- 10.0 ],
266+ ),
267+ dict (
268+ testcase_name = "over_max_length" ,
269+ lengths = [110 ],
270+ expected_scores = [- 15.0 ],
271+ ),
272+ dict (
273+ testcase_name = "mixed_lengths" ,
274+ lengths = [70 , 80 , 90 , 100 , 110 ],
275+ expected_scores = [0.0 , 0.0 , - 5.0 , - 10.0 , - 15.0 ],
276+ ),
277+ dict (
278+ testcase_name = "zero_penalty" ,
279+ lengths = [110 ],
280+ expected_scores = [0.0 ],
281+ penalty = 0 ,
282+ ),
283+ )
284+ def test_reward_scores (self , lengths , expected_scores , penalty = 10 ):
285+ completions = ["a" * length for length in lengths ]
286+ overlong_buffer = {
287+ "overlong_buffer_length" : 20 ,
288+ "overlong_buffer_penalty" : penalty ,
289+ "max_response_length" : 100 ,
290+ }
291+ # expected_response_length = 100 - 20 = 80
292+
293+ scores = dapo_lib .reward_shaping (
294+ prompts = ["" ] * len (completions ),
295+ completions = completions ,
296+ mode = self .mock_cluster .Mode ,
297+ overlong_buffer = overlong_buffer ,
298+ )
299+
300+ self .assertSequenceAlmostEqual (expected_scores , scores , places = 4 )
114301
115302
116303if __name__ == "__main__" :
0 commit comments