Skip to content

Commit 3b4b2a1

Browse files
committed
Improve coverage on inspect diffing and fix bugs, closes #1349
1 parent c8fd9b4 commit 3b4b2a1

File tree

2 files changed

+120
-56
lines changed

2 files changed

+120
-56
lines changed

nx/lib/nx/defn/template_diff.ex

Lines changed: 41 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,79 +1,65 @@
11
defmodule Nx.Defn.TemplateDiff do
22
@moduledoc false
3+
import Nx, only: [is_tensor: 1]
34
defstruct [:left, :right, :left_title, :right_title, :compatible]
45

5-
defp is_valid_container?(impl) do
6-
not is_nil(impl) and impl != Nx.Container.Any
6+
def build(left, right, left_title, right_title, _compatibility_fn \\ &Nx.compatible?/2)
7+
8+
def build(left, right, left_title, right_title, compatibility_fn)
9+
when is_tensor(left) and is_tensor(right) do
10+
%__MODULE__{
11+
left: left,
12+
left_title: left_title,
13+
right: right,
14+
right_title: right_title,
15+
compatible: compatibility_fn.(left, right)
16+
}
717
end
818

9-
def build(left, right, left_title, right_title, compatibility_fn \\ &Nx.compatible?/2) do
19+
def build(left, right, left_title, right_title, compatibility_fn) do
1020
left_impl = Nx.Container.impl_for(left)
1121
right_impl = Nx.Container.impl_for(right)
1222

13-
l = is_valid_container?(left_impl)
14-
r = is_valid_container?(right_impl)
23+
if left_impl == right_impl and left_impl != nil do
24+
flatten = right |> Nx.Container.reduce([], &[&1 | &2]) |> Enum.reverse()
1525

16-
cond do
17-
not l and not r ->
18-
%__MODULE__{
19-
left: left,
20-
left_title: left_title,
21-
right: right,
22-
right_title: right_title,
23-
compatible: left == right
24-
}
26+
{diff, acc} =
27+
Nx.Container.traverse(left, flatten, fn
28+
left, [] ->
29+
{%__MODULE__{left: left}, :incompatible_sizes}
2530

26-
not l or not r ->
27-
%__MODULE__{
28-
left: left,
29-
left_title: left_title,
30-
right: right,
31-
right_title: right_title,
32-
compatible: false
33-
}
31+
left, [right | acc] ->
32+
{build(left, right, left_title, right_title, compatibility_fn), acc}
33+
end)
3434

35-
left_impl != right_impl ->
35+
if acc == [] and compatible_keys?(left_impl, left, right) do
36+
diff
37+
else
3638
%__MODULE__{
3739
left: left,
3840
left_title: left_title,
3941
right: right,
4042
right_title: right_title,
4143
compatible: false
4244
}
43-
44-
l and r ->
45-
{diff, acc} =
46-
Nx.Defn.Composite.traverse(left, Nx.Defn.Composite.flatten_list([right]), fn
47-
left, [] ->
48-
{%__MODULE__{left: left}, :incompatible_sizes}
49-
50-
left, [right | acc] ->
51-
{
52-
%__MODULE__{
53-
left: left,
54-
right: right,
55-
left_title: left_title,
56-
right_title: right_title,
57-
compatible: compatibility_fn.(left, right)
58-
},
59-
acc
60-
}
61-
end)
62-
63-
if acc == :incompatible_sizes do
64-
%__MODULE__{
65-
left: left,
66-
left_title: left_title,
67-
right: right,
68-
right_title: right_title,
69-
compatible: false
70-
}
71-
else
72-
diff
73-
end
45+
end
46+
else
47+
%__MODULE__{
48+
left: left,
49+
left_title: left_title,
50+
right: right,
51+
right_title: right_title,
52+
compatible: false
53+
}
7454
end
7555
end
7656

57+
defp compatible_keys?(Nx.Container.Map, left, right),
58+
do: Enum.all?(Map.keys(left), &is_map_key(right, &1))
59+
60+
defp compatible_keys?(_, _, _),
61+
do: true
62+
7763
def build_and_inspect(
7864
left,
7965
right,
@@ -145,8 +131,7 @@ defmodule Nx.Defn.TemplateDiff do
145131
end
146132

147133
defp inspect_as_template(data, opts) do
148-
if is_number(data) or is_tuple(data) or
149-
(is_map(data) and Nx.Container.impl_for(data) != Nx.Container.Any) do
134+
if Nx.Container.impl_for(data) != nil do
150135
data
151136
|> Nx.to_template()
152137
|> to_doc(
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
defmodule Nx.Defn.TemplateDiffTest do
2+
use ExUnit.Case, async: true
3+
4+
defp build(left, right) do
5+
Nx.Defn.TemplateDiff.build(left, right, "left", "right", &Nx.compatible?/2)
6+
end
7+
8+
test "compatible" do
9+
assert build(1, 2) == %Nx.Defn.TemplateDiff{
10+
left: 1,
11+
right: 2,
12+
compatible: true,
13+
left_title: "left",
14+
right_title: "right"
15+
}
16+
end
17+
18+
test "incompatible" do
19+
assert build(1, Nx.tensor([2])) == %Nx.Defn.TemplateDiff{
20+
left: 1,
21+
right: Nx.tensor([2]),
22+
compatible: false,
23+
left_title: "left",
24+
right_title: "right"
25+
}
26+
end
27+
28+
test "implementation incompatible" do
29+
assert build(%{foo: 1}, Nx.tensor([2])) == %Nx.Defn.TemplateDiff{
30+
left: %{foo: 1},
31+
right: Nx.tensor([2]),
32+
compatible: false,
33+
left_title: "left",
34+
right_title: "right"
35+
}
36+
end
37+
38+
test "size incompatible" do
39+
assert build(%{foo: 1}, %{}) == %Nx.Defn.TemplateDiff{
40+
left: %{foo: 1},
41+
right: %{},
42+
compatible: false,
43+
left_title: "left",
44+
right_title: "right"
45+
}
46+
47+
assert build(%{}, %{foo: 1}) == %Nx.Defn.TemplateDiff{
48+
left: %{},
49+
right: %{foo: 1},
50+
compatible: false,
51+
left_title: "left",
52+
right_title: "right"
53+
}
54+
end
55+
56+
test "child incompatible" do
57+
assert build(%{foo: 1}, %{foo: Nx.tensor([2])}) ==
58+
%{
59+
foo: %Nx.Defn.TemplateDiff{
60+
left: 1,
61+
right: Nx.tensor([2]),
62+
compatible: false,
63+
left_title: "left",
64+
right_title: "right"
65+
}
66+
}
67+
end
68+
69+
test "keys incompatible" do
70+
assert build(%{foo: 1}, %{bar: 1}) ==
71+
%Nx.Defn.TemplateDiff{
72+
left: %{foo: 1},
73+
right: %{bar: 1},
74+
compatible: false,
75+
left_title: "left",
76+
right_title: "right"
77+
}
78+
end
79+
end

0 commit comments

Comments
 (0)