Skip to content

Commit b01e94a

Browse files
authored
fix: Nx.Defn.Expr will now promote constants based on the surrounding non-constant tensors for binary operations (#1570)
1 parent 8dc7b29 commit b01e94a

File tree

3 files changed

+69
-10
lines changed

3 files changed

+69
-10
lines changed

nx/lib/nx/defn/expr.ex

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1468,16 +1468,42 @@ defmodule Nx.Defn.Expr do
14681468
c1 = maybe_constant(arg1)
14691469
c2 = maybe_constant(arg2)
14701470

1471-
if c1 && c2 do
1472-
apply(Nx.BinaryBackend, op, [
1473-
%{out | shape: {}, names: []},
1474-
constant_binary(arg1, c1),
1475-
constant_binary(arg2, c2)
1476-
])
1477-
|> Nx.to_number()
1478-
|> then(&constant(out, &1))
1471+
cond do
1472+
c1 && c2 ->
1473+
apply(Nx.BinaryBackend, op, [
1474+
%{out | shape: {}, names: []},
1475+
constant_binary(arg1, c1),
1476+
constant_binary(arg2, c2)
1477+
])
1478+
|> Nx.to_number()
1479+
|> then(&constant(out, &1))
1480+
1481+
c1 ->
1482+
expr(out, context, op, [maybe_upcast_float_constant(arg1, out.type), arg2])
1483+
1484+
c2 ->
1485+
expr(out, context, op, [arg1, maybe_upcast_float_constant(arg2, out.type)])
1486+
1487+
true ->
1488+
expr(out, context, op, [arg1, arg2])
1489+
end
1490+
end
1491+
1492+
defp maybe_upcast_float_constant(
1493+
%T{type: type, data: %Expr{op: :constant, args: [number]}} = t,
1494+
out_type
1495+
) do
1496+
# By default, Elixir floats are 64 bits, so we're not really upcasting
1497+
# if out_type is higher precision than what's annotated.
1498+
# This is just so that downstream code that relies on this type annotation
1499+
# properly interprets the f64 value as the higher precision type.
1500+
# This also means that if out_type is lower precision, `number` will be
1501+
# downcast to the lower precision type.
1502+
1503+
if Nx.Type.float?(type) and Nx.Type.float?(out_type) do
1504+
constant(%{t | type: out_type}, number)
14791505
else
1480-
expr(out, context, op, [arg1, arg2])
1506+
t
14811507
end
14821508
end
14831509

nx/lib/nx/type.ex

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,9 @@ defmodule Nx.Type do
373373
bits. Otherwise it casts to f64.
374374
375375
In the case of complex numbers, the maximum bit size is 128 bits
376-
because they are composed of two floats.
376+
because they are composed of two floats. Float types are promoted
377+
to c64 by default, with the exception of f64, which is promoted to
378+
c128 so that a single component can represent an f64 number properly.
377379
378380
## Examples
379381
@@ -429,8 +431,15 @@ defmodule Nx.Type do
429431
iex> Nx.Type.merge({:f, 64}, {:bf, 16})
430432
{:f, 64}
431433
434+
iex> Nx.Type.merge({:f, 16}, {:c, 64})
435+
{:c, 64}
436+
iex> Nx.Type.merge({:f, 32}, {:c, 64})
437+
{:c, 64}
438+
iex> Nx.Type.merge({:f, 64}, {:c, 64})
439+
{:c, 128}
432440
iex> Nx.Type.merge({:c, 64}, {:f, 32})
433441
{:c, 64}
442+
434443
iex> Nx.Type.merge({:c, 64}, {:c, 64})
435444
{:c, 64}
436445
iex> Nx.Type.merge({:c, 128}, {:c, 64})
@@ -443,6 +452,7 @@ defmodule Nx.Type do
443452
def merge(left, right) do
444453
case sort(left, right) do
445454
{{:u, size1}, {:s, size2}} -> {:s, max(min(size1 * 2, 64), size2)}
455+
{{:f, size1}, {:c, size2}} -> {:c, max(size1 * 2, size2)}
446456
{_, type2} -> type2
447457
end
448458
end

nx/test/nx/defn/expr_test.exs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,29 @@ defmodule Nx.Defn.ExprTest do
198198
c = metadata b, :stop_grad s32[1]
199199
"""
200200
end
201+
202+
test "upcast float constants when operating against higher precision types" do
203+
t_f32 = Nx.tensor([2, 2], type: :f32) |> Expr.tensor()
204+
c_f64 = Expr.constant(Nx.tensor(0.7, type: :f64), 0.7, [])
205+
206+
assert %T{type: {:f, 64}, data: %Expr{op: :multiply, args: [^c_f64, ^t_f32]}} =
207+
Nx.multiply(t_f32, c_f64)
208+
209+
t_f64 = Nx.tensor([2, 2], type: :f64) |> Expr.tensor()
210+
c_f32 = Expr.constant(Nx.tensor(0.7, type: :f32), 0.7, [])
211+
212+
assert %T{type: {:f, 64}, data: %Expr{op: :multiply, args: [^c_f64, ^t_f64]}} =
213+
Nx.multiply(t_f64, c_f32)
214+
215+
c_c64 = Expr.constant(Nx.tensor(0.7, type: :c64), 0.7, [])
216+
c_c128 = Expr.constant(Nx.tensor(0.7, type: :c128), 0.7, [])
217+
218+
assert %T{type: {:c, 64}, data: %Expr{op: :multiply, args: [^c_c64, ^t_f32]}} =
219+
Nx.multiply(t_f32, c_c64)
220+
221+
assert %T{type: {:c, 128}, data: %Expr{op: :multiply, args: [^c_c128, ^t_f64]}} =
222+
Nx.multiply(t_f64, c_c64)
223+
end
201224
end
202225

203226
describe "inspect" do

0 commit comments

Comments
 (0)