Skip to content

Commit 4377a57

Browse files
zhuyuegongchensu
authored andcommitted
Issue/685 - Refactor random_sample test to use result-based comparison logic.
1 parent 5c1cb64 commit 4377a57

File tree

1 file changed

+83
-88
lines changed

1 file changed

+83
-88
lines changed

test/infinicore/ops/random_sample.py

Lines changed: 83 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -206,96 +206,91 @@ def run_test(self, device, test_case, config):
206206
# Clear stored logits before test to ensure fresh generation
207207
self._current_logits = None
208208

209-
try:
210-
# Try the standard comparison first
211-
# This will call prepare_pytorch_inputs_and_kwargs which will set self._current_logits
212-
return super().run_test(device, test_case, config)
213-
except AssertionError as original_error:
214-
# If standard comparison fails, check if this is a valid case where
215-
# indices differ but logits values are equal
216-
217-
# Only handle if we have stored logits (from prepare_pytorch_inputs_and_kwargs)
218-
if self._current_logits is None:
219-
raise
220-
221-
logits_tensor = self._current_logits
222-
223-
# Re-run operations with the same logits to get results for comparison
224-
# prepare_pytorch_inputs_and_kwargs will reuse self._current_logits if it exists
225-
from framework.base import TestResult
226-
from framework.utils import (
227-
convert_infinicore_to_torch,
228-
infinicore_tensor_from_torch,
229-
)
230-
231-
inputs, kwargs = self.prepare_pytorch_inputs_and_kwargs(test_case, device)
232-
233-
# Prepare infinicore inputs
234-
infini_inputs = []
235-
for inp in inputs:
236-
if isinstance(inp, torch.Tensor):
237-
cloned_inp = inp.clone().detach()
238-
infini_tensor = infinicore_tensor_from_torch(cloned_inp)
239-
infini_inputs.append(infini_tensor)
240-
else:
241-
infini_inputs.append(inp)
242-
243-
infini_kwargs = kwargs.copy()
244-
if "out" in infini_kwargs and isinstance(
245-
infini_kwargs["out"], torch.Tensor
246-
):
247-
cloned_out = infini_kwargs["out"].clone().detach()
248-
infini_kwargs["out"] = infinicore_tensor_from_torch(cloned_out)
249-
250-
# Run both operators
251-
torch_result = self.torch_operator(*inputs, **kwargs)
252-
infini_result = self.infinicore_operator(*infini_inputs, **infini_kwargs)
253-
254-
# Extract indices from results
255-
comparison_target = test_case.comparison_target
256-
if comparison_target == "out":
257-
# Compare output tensor from kwargs
258-
ref_idx = kwargs["out"].item()
259-
torch_result_from_infini = convert_infinicore_to_torch(
260-
infini_kwargs["out"]
261-
)
262-
ic_idx = torch_result_from_infini.item()
263-
else:
264-
# Compare return values
265-
ref_idx = torch_result.item()
266-
torch_result_from_infini = convert_infinicore_to_torch(infini_result)
267-
ic_idx = torch_result_from_infini.item()
268-
269-
# Check if indices are equal (standard case)
270-
if ic_idx == ref_idx:
271-
# Return a successful TestResult object
272-
return TestResult(
273-
success=True,
274-
return_code=0,
275-
test_case=test_case,
276-
device=device,
277-
)
278-
279-
# Special case: indices differ but logits values are equal
280-
# This is valid for random_sample when multiple indices have the same logits value
281-
try:
282-
logits_ref = logits_tensor[ref_idx].item()
283-
logits_ic = logits_tensor[ic_idx].item()
284-
if logits_ic == logits_ref:
285-
# Valid: different indices but same logits value
286-
# Return a successful TestResult object
287-
return TestResult(
288-
success=True,
289-
return_code=0,
290-
test_case=test_case,
291-
device=device,
209+
# Call parent's run_test, but intercept the result to check for failures
210+
result = super().run_test(device, test_case, config)
211+
212+
# Check if test failed and try special comparison logic
213+
if not result.success and "Result comparison failed" in result.error_message:
214+
# Try special comparison logic for random_sample
215+
# When indices differ but logits values are equal, the result is still valid
216+
if self._current_logits is not None:
217+
try:
218+
# Re-run operations with the same logits to get results for comparison
219+
from framework.base import TestResult
220+
from framework.utils import (
221+
convert_infinicore_to_torch,
222+
infinicore_tensor_from_torch,
292223
)
293-
except (IndexError, RuntimeError):
294-
# If we can't access the logits, fall through to raise the original error
295-
pass
296224

297-
# If we get here, the results are truly different
298-
raise original_error
225+
inputs, kwargs = self.prepare_pytorch_inputs_and_kwargs(test_case, device)
226+
227+
# Prepare infinicore inputs
228+
infini_inputs = []
229+
for inp in inputs:
230+
if isinstance(inp, torch.Tensor):
231+
cloned_inp = inp.clone().detach()
232+
infini_tensor = infinicore_tensor_from_torch(cloned_inp)
233+
infini_inputs.append(infini_tensor)
234+
else:
235+
infini_inputs.append(inp)
236+
237+
infini_kwargs = kwargs.copy()
238+
if "out" in infini_kwargs and isinstance(
239+
infini_kwargs["out"], torch.Tensor
240+
):
241+
cloned_out = infini_kwargs["out"].clone().detach()
242+
infini_kwargs["out"] = infinicore_tensor_from_torch(cloned_out)
243+
244+
# Run both operators
245+
torch_result = self.torch_operator(*inputs, **kwargs)
246+
infini_result = self.infinicore_operator(*infini_inputs, **infini_kwargs)
247+
248+
# Extract indices from results
249+
comparison_target = test_case.comparison_target
250+
if comparison_target == "out":
251+
ref_idx = kwargs["out"].item()
252+
torch_result_from_infini = convert_infinicore_to_torch(
253+
infini_kwargs["out"]
254+
)
255+
ic_idx = torch_result_from_infini.item()
256+
else:
257+
ref_idx = torch_result.item()
258+
torch_result_from_infini = convert_infinicore_to_torch(infini_result)
259+
ic_idx = torch_result_from_infini.item()
260+
261+
# Check if indices are equal (standard case)
262+
if ic_idx == ref_idx:
263+
return TestResult(
264+
success=True,
265+
return_code=0,
266+
test_case=test_case,
267+
device=device,
268+
)
269+
270+
# Special case: indices differ but logits values are equal
271+
# Match infiniop test logic: check if logits values are equal
272+
logits_tensor = self._current_logits
273+
logits_ref_val = logits_tensor[ref_idx]
274+
logits_ic_val = logits_tensor[ic_idx]
275+
276+
# For bfloat16/float16, convert to float32 for comparison (same as infiniop debug function)
277+
if logits_tensor.dtype in (torch.bfloat16, torch.float16):
278+
logits_ref_val = logits_ref_val.float()
279+
logits_ic_val = logits_ic_val.float()
280+
281+
if torch.equal(logits_ref_val, logits_ic_val):
282+
# Valid: different indices but same logits value
283+
return TestResult(
284+
success=True,
285+
return_code=0,
286+
test_case=test_case,
287+
device=device,
288+
)
289+
except Exception:
290+
# If special comparison fails, fall through to return original result
291+
pass
292+
293+
return result
299294

300295

301296
def main():

0 commit comments

Comments
 (0)