|
1 | 1 | defmodule Nx.Defn.TemplateDiff do
|
2 | 2 | @moduledoc false
|
| 3 | + import Nx, only: [is_tensor: 1] |
3 | 4 | defstruct [:left, :right, :left_title, :right_title, :compatible]
|
4 | 5 |
|
5 |
| - defp is_valid_container?(impl) do |
6 |
| - not is_nil(impl) and impl != Nx.Container.Any |
| 6 | + def build(left, right, left_title, right_title, _compatibility_fn \\ &Nx.compatible?/2) |
| 7 | + |
| 8 | + def build(left, right, left_title, right_title, compatibility_fn) |
| 9 | + when is_tensor(left) and is_tensor(right) do |
| 10 | + %__MODULE__{ |
| 11 | + left: left, |
| 12 | + left_title: left_title, |
| 13 | + right: right, |
| 14 | + right_title: right_title, |
| 15 | + compatible: compatibility_fn.(left, right) |
| 16 | + } |
7 | 17 | end
|
8 | 18 |
|
9 |
| - def build(left, right, left_title, right_title, compatibility_fn \\ &Nx.compatible?/2) do |
| 19 | + def build(left, right, left_title, right_title, compatibility_fn) do |
10 | 20 | left_impl = Nx.Container.impl_for(left)
|
11 | 21 | right_impl = Nx.Container.impl_for(right)
|
12 | 22 |
|
13 |
| - l = is_valid_container?(left_impl) |
14 |
| - r = is_valid_container?(right_impl) |
| 23 | + if left_impl == right_impl and left_impl != nil do |
| 24 | + flatten = right |> Nx.Container.reduce([], &[&1 | &2]) |> Enum.reverse() |
15 | 25 |
|
16 |
| - cond do |
17 |
| - not l and not r -> |
18 |
| - %__MODULE__{ |
19 |
| - left: left, |
20 |
| - left_title: left_title, |
21 |
| - right: right, |
22 |
| - right_title: right_title, |
23 |
| - compatible: left == right |
24 |
| - } |
| 26 | + {diff, acc} = |
| 27 | + Nx.Container.traverse(left, flatten, fn |
| 28 | + left, [] -> |
| 29 | + {%__MODULE__{left: left}, :incompatible_sizes} |
25 | 30 |
|
26 |
| - not l or not r -> |
27 |
| - %__MODULE__{ |
28 |
| - left: left, |
29 |
| - left_title: left_title, |
30 |
| - right: right, |
31 |
| - right_title: right_title, |
32 |
| - compatible: false |
33 |
| - } |
| 31 | + left, [right | acc] -> |
| 32 | + {build(left, right, left_title, right_title, compatibility_fn), acc} |
| 33 | + end) |
34 | 34 |
|
35 |
| - left_impl != right_impl -> |
| 35 | + if acc == [] and compatible_keys?(left_impl, left, right) do |
| 36 | + diff |
| 37 | + else |
36 | 38 | %__MODULE__{
|
37 | 39 | left: left,
|
38 | 40 | left_title: left_title,
|
39 | 41 | right: right,
|
40 | 42 | right_title: right_title,
|
41 | 43 | compatible: false
|
42 | 44 | }
|
43 |
| - |
44 |
| - l and r -> |
45 |
| - {diff, acc} = |
46 |
| - Nx.Defn.Composite.traverse(left, Nx.Defn.Composite.flatten_list([right]), fn |
47 |
| - left, [] -> |
48 |
| - {%__MODULE__{left: left}, :incompatible_sizes} |
49 |
| - |
50 |
| - left, [right | acc] -> |
51 |
| - { |
52 |
| - %__MODULE__{ |
53 |
| - left: left, |
54 |
| - right: right, |
55 |
| - left_title: left_title, |
56 |
| - right_title: right_title, |
57 |
| - compatible: compatibility_fn.(left, right) |
58 |
| - }, |
59 |
| - acc |
60 |
| - } |
61 |
| - end) |
62 |
| - |
63 |
| - if acc == :incompatible_sizes do |
64 |
| - %__MODULE__{ |
65 |
| - left: left, |
66 |
| - left_title: left_title, |
67 |
| - right: right, |
68 |
| - right_title: right_title, |
69 |
| - compatible: false |
70 |
| - } |
71 |
| - else |
72 |
| - diff |
73 |
| - end |
| 45 | + end |
| 46 | + else |
| 47 | + %__MODULE__{ |
| 48 | + left: left, |
| 49 | + left_title: left_title, |
| 50 | + right: right, |
| 51 | + right_title: right_title, |
| 52 | + compatible: false |
| 53 | + } |
74 | 54 | end
|
75 | 55 | end
|
76 | 56 |
|
| 57 | + defp compatible_keys?(Nx.Container.Map, left, right), |
| 58 | + do: Enum.all?(Map.keys(left), &is_map_key(right, &1)) |
| 59 | + |
| 60 | + defp compatible_keys?(_, _, _), |
| 61 | + do: true |
| 62 | + |
77 | 63 | def build_and_inspect(
|
78 | 64 | left,
|
79 | 65 | right,
|
@@ -145,8 +131,7 @@ defmodule Nx.Defn.TemplateDiff do
|
145 | 131 | end
|
146 | 132 |
|
147 | 133 | defp inspect_as_template(data, opts) do
|
148 |
| - if is_number(data) or is_tuple(data) or |
149 |
| - (is_map(data) and Nx.Container.impl_for(data) != Nx.Container.Any) do |
| 134 | + if Nx.Container.impl_for(data) != nil do |
150 | 135 | data
|
151 | 136 | |> Nx.to_template()
|
152 | 137 | |> to_doc(
|
|
0 commit comments