Skip to content

Commit

Permalink
fix: upcast floats in cond (#1573)
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente authored Jan 23, 2025
1 parent 32bf7e8 commit 5826779
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 4 deletions.
2 changes: 1 addition & 1 deletion nx/lib/nx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -943,7 +943,7 @@ defmodule Nx do

for t <-
[:u2, :u4, :u8, :u16, :u32, :u64, :s2, :s4, :s8, :s16, :s32, :s64] ++
[:f8, :bf16, :f16, :f32, :f64] do
[:f8, :bf16, :f16, :f32, :f64, :c64, :c128] do
@doc """
Short-hand function for creating tensor of type `#{t}`.
Expand Down
19 changes: 16 additions & 3 deletions nx/lib/nx/defn/expr.ex
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,18 @@ defmodule Nx.Defn.Expr do

result =
for expr <- [last | exprs] do
expr
|> Nx.as_type(type)
|> Nx.broadcast(shape, names: names)
typed_expr =
case expr do
%T{data: %Expr{op: :constant}} ->
expr
|> maybe_upcast_float_constant(type)
|> Nx.as_type(type)

expr ->
Nx.as_type(expr, type)
end

Nx.broadcast(typed_expr, shape, names: names)
end

{result, vectorized_axes}
Expand Down Expand Up @@ -1401,6 +1410,10 @@ defmodule Nx.Defn.Expr do
defp constant(%{shape: shape, type: type} = out, number) do
number =
cond do
Nx.Type.complex?(type) and
(is_number(number) or number in [:infinity, :neg_infinity, :nan]) ->
Complex.new(number, 0.0)

is_integer(number) and Nx.Type.float?(type) ->
Complex.multiply(1.0, number)

Expand Down
42 changes: 42 additions & 0 deletions nx/test/nx/defn_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1176,6 +1176,48 @@ defmodule Nx.DefnTest do
)
end

defn cond_upcast_float_literals(n) do
cond do
n == 1 -> 1.4
n == 2 -> 2
true -> n
end
end

test "upcasts float literals based on the accumulated clause type" do
for input_type <- [f: 32, f: 64] do
assert %T{
type: ^input_type,
data: %Expr{op: :cond, args: [[clause1, clause2], _last]}
} =
cond_upcast_float_literals(Nx.tensor(10.0, type: input_type))

assert {_, %T{type: ^input_type, data: %Expr{op: :constant, args: [1.4]}}} = clause1
assert {_, %T{type: ^input_type, data: %Expr{op: :constant, args: [2.0]}}} = clause2
end

for input_type <- [c: 64, c: 128] do
assert %T{
type: ^input_type,
data: %Expr{op: :cond, args: [[clause1, clause2], _last]}
} =
cond_upcast_float_literals(Nx.tensor(10.0, type: input_type))

assert {_,
%T{
type: ^input_type,
data: %Expr{op: :constant, args: [%Complex{re: 1.4, im: +0.0}]}
}} = clause1

assert {_,
%T{
type: ^input_type,
data: %Expr{op: :constant, args: [%Complex{re: 2.0, im: +0.0}]}
}} =
clause2
end
end

defn cond_list(a) do
if Nx.any(a), do: 1, else: -1
end
Expand Down

0 comments on commit 5826779

Please sign in to comment.