From c31ecb94b59908b07d275265152cd61fadac0ced Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 23 Jan 2025 14:43:32 -0300 Subject: [PATCH] fix: use Nx.as_type as well --- nx/lib/nx/defn/expr.ex | 13 +++++++++++-- nx/test/nx/defn_test.exs | 25 +++++++++++++++++++++++-- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex index 4e495e2f49..6386769e22 100644 --- a/nx/lib/nx/defn/expr.ex +++ b/nx/lib/nx/defn/expr.ex @@ -251,8 +251,13 @@ defmodule Nx.Defn.Expr do for expr <- [last | exprs] do typed_expr = case expr do - %T{data: %Expr{op: :constant}} -> maybe_upcast_float_constant(expr, type) - expr -> Nx.as_type(expr, type) + %T{data: %Expr{op: :constant}} -> + expr + |> maybe_upcast_float_constant(type) + |> Nx.as_type(type) + + expr -> + Nx.as_type(expr, type) end Nx.broadcast(typed_expr, shape, names: names) @@ -1405,6 +1410,10 @@ defmodule Nx.Defn.Expr do defp constant(%{shape: shape, type: type} = out, number) do number = cond do + Nx.Type.complex?(type) and + (is_number(number) or number in [:infinity, :neg_infinity, :nan]) -> + Complex.new(number, 0.0) + is_integer(number) and Nx.Type.float?(type) -> Complex.multiply(1.0, number) diff --git a/nx/test/nx/defn_test.exs b/nx/test/nx/defn_test.exs index 568e186050..42bc0f8dcb 100644 --- a/nx/test/nx/defn_test.exs +++ b/nx/test/nx/defn_test.exs @@ -1185,7 +1185,7 @@ defmodule Nx.DefnTest do end test "upcasts float literals based on the accumulated clause type" do - for input_type <- [f: 32, f: 64, c: 64, c: 128] do + for input_type <- [f: 32, f: 64] do assert %T{ type: ^input_type, data: %Expr{op: :cond, args: [[clause1, clause2], _last]} @@ -1193,7 +1193,28 @@ defmodule Nx.DefnTest do cond_upcast_float_literals(Nx.tensor(10.0, type: input_type)) assert {_, %T{type: ^input_type, data: %Expr{op: :constant, args: [1.4]}}} = clause1 - assert {_, %T{type: {:s, 32}, data: %Expr{op: :constant, args: [2]}}} = clause2 + assert {_, %T{type: ^input_type, data: %Expr{op: :constant, args: [2.0]}}} = clause2 + end + + for input_type <- [c: 64, c: 128] do + assert %T{ + type: ^input_type, + data: %Expr{op: :cond, args: [[clause1, clause2], _last]} + } = + cond_upcast_float_literals(Nx.tensor(10.0, type: input_type)) + + assert {_, + %T{ + type: ^input_type, + data: %Expr{op: :constant, args: [%Complex{re: 1.4, im: +0.0}]} + }} = clause1 + + assert {_, + %T{ + type: ^input_type, + data: %Expr{op: :constant, args: [%Complex{re: 2.0, im: +0.0}]} + }} = + clause2 end end