Skip to content

Commit 3b1c7fd

Browse files
authored
Add numerical comparator base class and L1 comparator
Differential Revision: D76746854 Pull Request resolved: #11751
1 parent 5365c55 commit 3b1c7fd

File tree

7 files changed

+191
-0
lines changed

7 files changed

+191
-0
lines changed
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
2+
3+
oncall("executorch")
4+
5+
6+
python_library(
7+
name = "numerical_comparator_base",
8+
srcs = ["numerical_comparator_base.py"],
9+
deps = [],
10+
)
11+
12+
13+
python_library(
14+
name = "l1_numerical_comparator",
15+
srcs = ["l1_numerical_comparator.py"],
16+
deps = [
17+
"//executorch/devtools/inspector/numerical_comparator:numerical_comparator_base",
18+
"//executorch/devtools/inspector:lib",
19+
],
20+
)
21+
22+
23+
24+
python_library(
25+
name = "lib",
26+
srcs = ["__init__.py"],
27+
deps = [
28+
":l1_numerical_comparator",
29+
],
30+
)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
from executorch.devtools.inspector.numerical_comparator.l1_numerical_comparator import (
9+
L1Comparator,
10+
)
11+
12+
13+
__all__ = ["L1Comparator"]
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
10+
from abc import ABC, abstractmethod
11+
from typing import Any
12+
13+
14+
class InspectorNumericalComparatorBase(ABC):
15+
@abstractmethod
16+
def compare(self, a: Any, b: Any) -> float:
17+
"""Compare two intermediate output and return a result.
18+
19+
This method should be overridden by subclasses to provide custom comparison logic.
20+
21+
Args:
22+
a: The first intermediate output to compare.
23+
b: The second intermediate output to compare.
24+
25+
Returns:
26+
A numerical result indicating the comparison outcome.
27+
"""
28+
pass
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Any
8+
9+
import torch
10+
from executorch.devtools.inspector._inspector_utils import convert_to_float_tensor
11+
from executorch.devtools.inspector.numerical_comparator.numerical_comparator_base import (
12+
NumericalComparatorBase,
13+
)
14+
15+
16+
class L1Comparator(NumericalComparatorBase):
17+
def compare(self, a: Any, b: Any) -> float:
18+
"""Sum up all these element-wise absolute differences between two tensors."""
19+
20+
t_a = convert_to_float_tensor(a)
21+
t_b = convert_to_float_tensor(b)
22+
if torch.isnan(t_a).any() or torch.isnan(t_b).any():
23+
t_a = torch.nan_to_num(t_a)
24+
t_b = torch.nan_to_num(t_b)
25+
26+
try:
27+
res = torch.abs(t_a - t_b).sum().item()
28+
except Exception as e:
29+
raise ValueError(f"Error computing L1 difference between tensors: {str(e)}")
30+
return res
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
from abc import ABC, abstractmethod
9+
from typing import Any
10+
11+
12+
class NumericalComparatorBase(ABC):
13+
@abstractmethod
14+
def compare(self, a: Any, b: Any) -> float:
15+
"""Compare two intermediate output and return a result.
16+
17+
This method should be overridden by subclasses to provide custom comparison logic.
18+
19+
Args:
20+
a: The first intermediate output to compare.
21+
b: The second intermediate output to compare.
22+
23+
Returns:
24+
A numerical result indicating the comparison outcome.
25+
"""
26+
pass

devtools/inspector/tests/TARGETS

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,14 @@ python_unittest(
5454
],
5555
)
5656

57+
python_unittest(
58+
name = "l1_comparator_test",
59+
srcs = ["l1_comparator_test.py"],
60+
deps = [
61+
"//executorch/devtools/inspector/numerical_comparator:lib",
62+
],
63+
)
64+
5765
python_library(
5866
name = "inspector_test_utils",
5967
srcs = [
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
11+
from executorch.devtools.inspector.numerical_comparator import L1Comparator
12+
13+
14+
class TestL1Comparator(unittest.TestCase):
15+
l1_comparator = L1Comparator()
16+
17+
def test_identical_tensors(self):
18+
a = torch.tensor([[1, 2], [3, 4]])
19+
b = torch.tensor([[1, 2], [3, 4]])
20+
expected = 0.0
21+
result = self.l1_comparator.compare(a, b)
22+
self.assertAlmostEqual(result, expected)
23+
24+
def test_scalar(self):
25+
a = 1
26+
b = 2
27+
expected = 1.0
28+
result = self.l1_comparator.compare(a, b)
29+
self.assertAlmostEqual(result, expected)
30+
31+
def test_with_nans_replaced_with_zero(self):
32+
a = torch.tensor([3, 2, -1, float("nan")])
33+
b = torch.tensor([float("nan"), 0, -3, 1])
34+
expected = 8.0
35+
result = self.l1_comparator.compare(a, b)
36+
self.assertAlmostEqual(result, expected)
37+
38+
def test_shape_mismatch_raises_exception(self):
39+
a = torch.tensor([0, 2, -1])
40+
b = torch.tensor([1, 0, -3, 4])
41+
with self.assertRaises(ValueError):
42+
self.l1_comparator.compare(a, b)
43+
44+
def test_2D_tensors(self):
45+
a = torch.tensor([[4, 9], [6, 4]])
46+
b = torch.tensor([[1, 2], [3, 5]])
47+
expected = 14.0
48+
result = self.l1_comparator.compare(a, b)
49+
self.assertAlmostEqual(result, expected)
50+
51+
def test_list_of_tensors(self):
52+
a = [torch.tensor([2, 4]), torch.tensor([5, 2])]
53+
b = [torch.tensor([1, 2]), torch.tensor([3, 5])]
54+
expected = 8.0
55+
result = self.l1_comparator.compare(a, b)
56+
self.assertAlmostEqual(result, expected)

0 commit comments

Comments
 (0)