@@ -206,96 +206,91 @@ def run_test(self, device, test_case, config):
206206 # Clear stored logits before test to ensure fresh generation
207207 self ._current_logits = None
208208
209- try :
210- # Try the standard comparison first
211- # This will call prepare_pytorch_inputs_and_kwargs which will set self._current_logits
212- return super ().run_test (device , test_case , config )
213- except AssertionError as original_error :
214- # If standard comparison fails, check if this is a valid case where
215- # indices differ but logits values are equal
216-
217- # Only handle if we have stored logits (from prepare_pytorch_inputs_and_kwargs)
218- if self ._current_logits is None :
219- raise
220-
221- logits_tensor = self ._current_logits
222-
223- # Re-run operations with the same logits to get results for comparison
224- # prepare_pytorch_inputs_and_kwargs will reuse self._current_logits if it exists
225- from framework .base import TestResult
226- from framework .utils import (
227- convert_infinicore_to_torch ,
228- infinicore_tensor_from_torch ,
229- )
230-
231- inputs , kwargs = self .prepare_pytorch_inputs_and_kwargs (test_case , device )
232-
233- # Prepare infinicore inputs
234- infini_inputs = []
235- for inp in inputs :
236- if isinstance (inp , torch .Tensor ):
237- cloned_inp = inp .clone ().detach ()
238- infini_tensor = infinicore_tensor_from_torch (cloned_inp )
239- infini_inputs .append (infini_tensor )
240- else :
241- infini_inputs .append (inp )
242-
243- infini_kwargs = kwargs .copy ()
244- if "out" in infini_kwargs and isinstance (
245- infini_kwargs ["out" ], torch .Tensor
246- ):
247- cloned_out = infini_kwargs ["out" ].clone ().detach ()
248- infini_kwargs ["out" ] = infinicore_tensor_from_torch (cloned_out )
249-
250- # Run both operators
251- torch_result = self .torch_operator (* inputs , ** kwargs )
252- infini_result = self .infinicore_operator (* infini_inputs , ** infini_kwargs )
253-
254- # Extract indices from results
255- comparison_target = test_case .comparison_target
256- if comparison_target == "out" :
257- # Compare output tensor from kwargs
258- ref_idx = kwargs ["out" ].item ()
259- torch_result_from_infini = convert_infinicore_to_torch (
260- infini_kwargs ["out" ]
261- )
262- ic_idx = torch_result_from_infini .item ()
263- else :
264- # Compare return values
265- ref_idx = torch_result .item ()
266- torch_result_from_infini = convert_infinicore_to_torch (infini_result )
267- ic_idx = torch_result_from_infini .item ()
268-
269- # Check if indices are equal (standard case)
270- if ic_idx == ref_idx :
271- # Return a successful TestResult object
272- return TestResult (
273- success = True ,
274- return_code = 0 ,
275- test_case = test_case ,
276- device = device ,
277- )
278-
279- # Special case: indices differ but logits values are equal
280- # This is valid for random_sample when multiple indices have the same logits value
281- try :
282- logits_ref = logits_tensor [ref_idx ].item ()
283- logits_ic = logits_tensor [ic_idx ].item ()
284- if logits_ic == logits_ref :
285- # Valid: different indices but same logits value
286- # Return a successful TestResult object
287- return TestResult (
288- success = True ,
289- return_code = 0 ,
290- test_case = test_case ,
291- device = device ,
209+ # Call parent's run_test, but intercept the result to check for failures
210+ result = super ().run_test (device , test_case , config )
211+
212+ # Check if test failed and try special comparison logic
213+ if not result .success and "Result comparison failed" in result .error_message :
214+ # Try special comparison logic for random_sample
215+ # When indices differ but logits values are equal, the result is still valid
216+ if self ._current_logits is not None :
217+ try :
218+ # Re-run operations with the same logits to get results for comparison
219+ from framework .base import TestResult
220+ from framework .utils import (
221+ convert_infinicore_to_torch ,
222+ infinicore_tensor_from_torch ,
292223 )
293- except (IndexError , RuntimeError ):
294- # If we can't access the logits, fall through to raise the original error
295- pass
296224
297- # If we get here, the results are truly different
298- raise original_error
225+ inputs , kwargs = self .prepare_pytorch_inputs_and_kwargs (test_case , device )
226+
227+ # Prepare infinicore inputs
228+ infini_inputs = []
229+ for inp in inputs :
230+ if isinstance (inp , torch .Tensor ):
231+ cloned_inp = inp .clone ().detach ()
232+ infini_tensor = infinicore_tensor_from_torch (cloned_inp )
233+ infini_inputs .append (infini_tensor )
234+ else :
235+ infini_inputs .append (inp )
236+
237+ infini_kwargs = kwargs .copy ()
238+ if "out" in infini_kwargs and isinstance (
239+ infini_kwargs ["out" ], torch .Tensor
240+ ):
241+ cloned_out = infini_kwargs ["out" ].clone ().detach ()
242+ infini_kwargs ["out" ] = infinicore_tensor_from_torch (cloned_out )
243+
244+ # Run both operators
245+ torch_result = self .torch_operator (* inputs , ** kwargs )
246+ infini_result = self .infinicore_operator (* infini_inputs , ** infini_kwargs )
247+
248+ # Extract indices from results
249+ comparison_target = test_case .comparison_target
250+ if comparison_target == "out" :
251+ ref_idx = kwargs ["out" ].item ()
252+ torch_result_from_infini = convert_infinicore_to_torch (
253+ infini_kwargs ["out" ]
254+ )
255+ ic_idx = torch_result_from_infini .item ()
256+ else :
257+ ref_idx = torch_result .item ()
258+ torch_result_from_infini = convert_infinicore_to_torch (infini_result )
259+ ic_idx = torch_result_from_infini .item ()
260+
261+ # Check if indices are equal (standard case)
262+ if ic_idx == ref_idx :
263+ return TestResult (
264+ success = True ,
265+ return_code = 0 ,
266+ test_case = test_case ,
267+ device = device ,
268+ )
269+
270+ # Special case: indices differ but logits values are equal
271+ # Match infiniop test logic: check if logits values are equal
272+ logits_tensor = self ._current_logits
273+ logits_ref_val = logits_tensor [ref_idx ]
274+ logits_ic_val = logits_tensor [ic_idx ]
275+
276+ # For bfloat16/float16, convert to float32 for comparison (same as infiniop debug function)
277+ if logits_tensor .dtype in (torch .bfloat16 , torch .float16 ):
278+ logits_ref_val = logits_ref_val .float ()
279+ logits_ic_val = logits_ic_val .float ()
280+
281+ if torch .equal (logits_ref_val , logits_ic_val ):
282+ # Valid: different indices but same logits value
283+ return TestResult (
284+ success = True ,
285+ return_code = 0 ,
286+ test_case = test_case ,
287+ device = device ,
288+ )
289+ except Exception :
290+ # If special comparison fails, fall through to return original result
291+ pass
292+
293+ return result
299294
300295
301296def main ():
0 commit comments