Skip to content

Commit

Permalink
feat: allow complex literals in defn (#1572)
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente authored Jan 23, 2025
1 parent 5826779 commit 9cfcd05
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 1 deletion.
7 changes: 7 additions & 0 deletions nx/lib/nx/defn/compiler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,13 @@ defmodule Nx.Defn.Compiler do
{{{:., dot_meta, [Nx, name]}, meta, args}, state}
end

# We also allow specifically Complex.new so that literal complex numbers
# can be written in defn.
defp normalize({{:., dot_meta, [Complex, :new]}, meta, args}, state) do
{args, state} = normalize_list(args, state)
{{{:., dot_meta, [Complex, :new]}, meta, args}, state}
end

defp normalize({{:., dot_meta, [mod, name]}, meta, args}, state) when mod in @allowed_modules do
{args, state} = normalize_list(args, state)
{{{:., dot_meta, [mod, name]}, meta, args}, state}
Expand Down
2 changes: 1 addition & 1 deletion nx/lib/nx/defn/expr.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1280,7 +1280,7 @@ defmodule Nx.Defn.Expr do
"value and inline it inside the defn expression. Got: #{inspect(t)}"
end

defp to_expr(number) when is_number(number),
defp to_expr(number) when is_number(number) or is_struct(number, Complex),
do: constant(%T{shape: {}, names: [], type: Nx.Type.infer(number)}, number)

defp to_expr(other) do
Expand Down
9 changes: 9 additions & 0 deletions nx/test/nx/defn_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ defmodule Nx.DefnTest do
@tensor [1, 2, 3]
defn(list_constant, do: Nx.tensor(@tensor))

defn complex_constant do
Complex.new(1, :infinity)
end

test "from list" do
assert %T{data: %Expr{op: :tensor}} = list_constant()
end
Expand All @@ -35,6 +39,11 @@ defmodule Nx.DefnTest do
test "from binary" do
assert %T{data: %Expr{op: :tensor}} = binary_constant()
end

test "complex literals" do
assert %T{data: %Expr{op: :constant, args: [%Complex{} = c]}} = complex_constant()
assert c == Complex.new(1, :infinity)
end
end

describe "Nx.tensor" do
Expand Down

0 comments on commit 9cfcd05

Please sign in to comment.