Skip to content

Commit 9cfcd05

Browse files
authored
feat: allow complex literals in defn (#1572)
1 parent 5826779 commit 9cfcd05

File tree

3 files changed

+17
-1
lines changed

3 files changed

+17
-1
lines changed

nx/lib/nx/defn/compiler.ex

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,13 @@ defmodule Nx.Defn.Compiler do
585585
{{{:., dot_meta, [Nx, name]}, meta, args}, state}
586586
end
587587

588+
# We also allow specifically Complex.new so that literal complex numbers
589+
# can be written in defn.
590+
defp normalize({{:., dot_meta, [Complex, :new]}, meta, args}, state) do
591+
{args, state} = normalize_list(args, state)
592+
{{{:., dot_meta, [Complex, :new]}, meta, args}, state}
593+
end
594+
588595
defp normalize({{:., dot_meta, [mod, name]}, meta, args}, state) when mod in @allowed_modules do
589596
{args, state} = normalize_list(args, state)
590597
{{{:., dot_meta, [mod, name]}, meta, args}, state}

nx/lib/nx/defn/expr.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1280,7 +1280,7 @@ defmodule Nx.Defn.Expr do
12801280
"value and inline it inside the defn expression. Got: #{inspect(t)}"
12811281
end
12821282

1283-
defp to_expr(number) when is_number(number),
1283+
defp to_expr(number) when is_number(number) or is_struct(number, Complex),
12841284
do: constant(%T{shape: {}, names: [], type: Nx.Type.infer(number)}, number)
12851285

12861286
defp to_expr(other) do

nx/test/nx/defn_test.exs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ defmodule Nx.DefnTest do
2525
@tensor [1, 2, 3]
2626
defn(list_constant, do: Nx.tensor(@tensor))
2727

28+
defn complex_constant do
29+
Complex.new(1, :infinity)
30+
end
31+
2832
test "from list" do
2933
assert %T{data: %Expr{op: :tensor}} = list_constant()
3034
end
@@ -35,6 +39,11 @@ defmodule Nx.DefnTest do
3539
test "from binary" do
3640
assert %T{data: %Expr{op: :tensor}} = binary_constant()
3741
end
42+
43+
test "complex literals" do
44+
assert %T{data: %Expr{op: :constant, args: [%Complex{} = c]}} = complex_constant()
45+
assert c == Complex.new(1, :infinity)
46+
end
3847
end
3948

4049
describe "Nx.tensor" do

0 commit comments

Comments
 (0)