Skip to content

Commit

Permalink
fix: complex equality should work between complex and number (#1571)
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente authored Jan 23, 2025
1 parent b01e94a commit 32bf7e8
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 3 deletions.
22 changes: 20 additions & 2 deletions nx/lib/nx/binary_backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -755,11 +755,29 @@ defmodule Nx.BinaryBackend do

defp element_equal(_, :nan, _), do: 0
defp element_equal(_, _, :nan), do: 0
defp element_equal(_, a, b), do: boolean_as_number(a == b)

defp element_equal(_, a, b) do
bool =
case {a, b} do
{%Complex{re: re_a, im: im_a}, b} when is_number(b) ->
re_a == b and im_a == 0

{a, %Complex{re: re_b, im: im_b}} when is_number(a) ->
a == re_b and im_b == 0

{a, b} ->
a == b
end

boolean_as_number(bool)
end

defp element_not_equal(_, :nan, _), do: 1
defp element_not_equal(_, _, :nan), do: 1
defp element_not_equal(_, a, b), do: boolean_as_number(a != b)

defp element_not_equal(out, a, b) do
1 - element_equal(out, a, b)
end

defp element_logical_and(_, a, b), do: boolean_as_number(as_boolean(a) and as_boolean(b))
defp element_logical_or(_, a, b), do: boolean_as_number(as_boolean(a) or as_boolean(b))
Expand Down
54 changes: 54 additions & 0 deletions nx/test/nx/complex_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,60 @@ defmodule Nx.ComplexTest do
end
end
end

test "equal" do
one_r = 1
one_u8 = Nx.tensor(1, type: {:u, 8})
zero_r = 0
zero_u8 = Nx.tensor(0, type: {:u, 8})
one_c = Complex.new(1, 0)
zero_c = Complex.new(0, 0)

assert Nx.equal(one_r, one_r) == one_u8
assert Nx.equal(one_r, zero_r) == zero_u8

assert Nx.equal(one_c, one_c) == one_u8
assert Nx.equal(one_c, zero_c) == zero_u8

assert Nx.equal(one_r, one_c) == one_u8
assert Nx.equal(zero_r, one_c) == zero_u8

assert Nx.equal(one_c, one_r) == one_u8
assert Nx.equal(one_c, zero_r) == zero_u8

assert Nx.equal(:nan, one_r) == zero_u8
assert Nx.equal(:nan, one_c) == zero_u8

assert Nx.equal(one_r, :nan) == zero_u8
assert Nx.equal(one_c, :nan) == zero_u8
end

test "not_equal" do
one_r = 1
one_u8 = Nx.tensor(1, type: {:u, 8})
zero_r = 0
zero_u8 = Nx.tensor(0, type: {:u, 8})
one_c = Complex.new(1, 0)
zero_c = Complex.new(0, 0)

assert Nx.not_equal(one_r, one_r) == zero_u8
assert Nx.not_equal(one_r, zero_r) == one_u8

assert Nx.not_equal(one_c, one_c) == zero_u8
assert Nx.not_equal(one_c, zero_c) == one_u8

assert Nx.not_equal(one_r, one_c) == zero_u8
assert Nx.not_equal(zero_r, one_c) == one_u8

assert Nx.not_equal(one_c, one_r) == zero_u8
assert Nx.not_equal(one_c, zero_r) == one_u8

assert Nx.not_equal(:nan, one_r) == one_u8
assert Nx.not_equal(:nan, one_c) == one_u8

assert Nx.not_equal(one_r, :nan) == one_u8
assert Nx.not_equal(one_c, :nan) == one_u8
end
end

describe "LinAlg not yet implemented" do
Expand Down
5 changes: 4 additions & 1 deletion nx/test/nx/defn/grad_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,10 @@ defmodule Nx.Defn.GradTest do
lhs = grad_mean_conv_y_general_stride_rhs_dilated(x, y)

rhs =
Nx.tensor([[[[7.4000006, 8.2], [7.4000006, 8.2]]], [[[7.4000006, 8.2], [7.4000006, 8.2]]]])
Nx.tensor([
[[[7.4000006, 8.2], [7.4000006, 8.2]]],
[[[7.4000006, 8.2], [7.4000006, 8.2]]]
])

assert_all_close(lhs, rhs)
end
Expand Down

0 comments on commit 32bf7e8

Please sign in to comment.