diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex index ef0b5c3a82..638891eaf1 100644 --- a/nx/lib/nx/defn/expr.ex +++ b/nx/lib/nx/defn/expr.ex @@ -1468,16 +1468,42 @@ defmodule Nx.Defn.Expr do c1 = maybe_constant(arg1) c2 = maybe_constant(arg2) - if c1 && c2 do - apply(Nx.BinaryBackend, op, [ - %{out | shape: {}, names: []}, - constant_binary(arg1, c1), - constant_binary(arg2, c2) - ]) - |> Nx.to_number() - |> then(&constant(out, &1)) + cond do + c1 && c2 -> + apply(Nx.BinaryBackend, op, [ + %{out | shape: {}, names: []}, + constant_binary(arg1, c1), + constant_binary(arg2, c2) + ]) + |> Nx.to_number() + |> then(&constant(out, &1)) + + c1 -> + expr(out, context, op, [maybe_upcast_float_constant(arg1, out.type), arg2]) + + c2 -> + expr(out, context, op, [arg1, maybe_upcast_float_constant(arg2, out.type)]) + + true -> + expr(out, context, op, [arg1, arg2]) + end + end + + defp maybe_upcast_float_constant( + %T{type: type, data: %Expr{op: :constant, args: [number]}} = t, + out_type + ) do + # By default, Elixir floats are 64 bits, so we're not really upcasting + # if out_type is higher precision than what's annotated. + # This is just so that downstream code that relies on this type annotation + # properly interprets the f64 value as the higher precision type. + # This also means that if out_type is lower precision, `number` will be + # downcast to the lower precision type. + + if Nx.Type.float?(type) and Nx.Type.float?(out_type) do + constant(%{t | type: out_type}, number) else - expr(out, context, op, [arg1, arg2]) + t end end diff --git a/nx/lib/nx/type.ex b/nx/lib/nx/type.ex index 2097693d6a..c529d09c31 100644 --- a/nx/lib/nx/type.ex +++ b/nx/lib/nx/type.ex @@ -373,7 +373,9 @@ defmodule Nx.Type do bits. Otherwise it casts to f64. In the case of complex numbers, the maximum bit size is 128 bits - because they are composed of two floats. + because they are composed of two floats. Float types are promoted + to c64 by default, with the exception of f64, which is promoted to + c128 so that a single component can represent an f64 number properly. ## Examples @@ -429,8 +431,15 @@ defmodule Nx.Type do iex> Nx.Type.merge({:f, 64}, {:bf, 16}) {:f, 64} + iex> Nx.Type.merge({:f, 16}, {:c, 64}) + {:c, 64} + iex> Nx.Type.merge({:f, 32}, {:c, 64}) + {:c, 64} + iex> Nx.Type.merge({:f, 64}, {:c, 64}) + {:c, 128} iex> Nx.Type.merge({:c, 64}, {:f, 32}) {:c, 64} + iex> Nx.Type.merge({:c, 64}, {:c, 64}) {:c, 64} iex> Nx.Type.merge({:c, 128}, {:c, 64}) @@ -443,6 +452,7 @@ defmodule Nx.Type do def merge(left, right) do case sort(left, right) do {{:u, size1}, {:s, size2}} -> {:s, max(min(size1 * 2, 64), size2)} + {{:f, size1}, {:c, size2}} -> {:c, max(size1 * 2, size2)} {_, type2} -> type2 end end diff --git a/nx/test/nx/defn/expr_test.exs b/nx/test/nx/defn/expr_test.exs index e901c83392..e579b39211 100644 --- a/nx/test/nx/defn/expr_test.exs +++ b/nx/test/nx/defn/expr_test.exs @@ -198,6 +198,29 @@ defmodule Nx.Defn.ExprTest do c = metadata b, :stop_grad s32[1] """ end + + test "upcast float constants when operating against higher precision types" do + t_f32 = Nx.tensor([2, 2], type: :f32) |> Expr.tensor() + c_f64 = Expr.constant(Nx.tensor(0.7, type: :f64), 0.7, []) + + assert %T{type: {:f, 64}, data: %Expr{op: :multiply, args: [^c_f64, ^t_f32]}} = + Nx.multiply(t_f32, c_f64) + + t_f64 = Nx.tensor([2, 2], type: :f64) |> Expr.tensor() + c_f32 = Expr.constant(Nx.tensor(0.7, type: :f32), 0.7, []) + + assert %T{type: {:f, 64}, data: %Expr{op: :multiply, args: [^c_f64, ^t_f64]}} = + Nx.multiply(t_f64, c_f32) + + c_c64 = Expr.constant(Nx.tensor(0.7, type: :c64), 0.7, []) + c_c128 = Expr.constant(Nx.tensor(0.7, type: :c128), 0.7, []) + + assert %T{type: {:c, 64}, data: %Expr{op: :multiply, args: [^c_c64, ^t_f32]}} = + Nx.multiply(t_f32, c_c64) + + assert %T{type: {:c, 128}, data: %Expr{op: :multiply, args: [^c_c128, ^t_f64]}} = + Nx.multiply(t_f64, c_c64) + end end describe "inspect" do