Skip to content

Commit

Permalink
fix: use Nx.as_type as well
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente committed Jan 23, 2025
1 parent 6b9d819 commit c31ecb9
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
13 changes: 11 additions & 2 deletions nx/lib/nx/defn/expr.ex
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,13 @@ defmodule Nx.Defn.Expr do
for expr <- [last | exprs] do
typed_expr =
case expr do
%T{data: %Expr{op: :constant}} -> maybe_upcast_float_constant(expr, type)
expr -> Nx.as_type(expr, type)
%T{data: %Expr{op: :constant}} ->
expr
|> maybe_upcast_float_constant(type)
|> Nx.as_type(type)

expr ->
Nx.as_type(expr, type)
end

Nx.broadcast(typed_expr, shape, names: names)
Expand Down Expand Up @@ -1405,6 +1410,10 @@ defmodule Nx.Defn.Expr do
defp constant(%{shape: shape, type: type} = out, number) do
number =
cond do
Nx.Type.complex?(type) and
(is_number(number) or number in [:infinity, :neg_infinity, :nan]) ->
Complex.new(number, 0.0)

is_integer(number) and Nx.Type.float?(type) ->
Complex.multiply(1.0, number)

Expand Down
25 changes: 23 additions & 2 deletions nx/test/nx/defn_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1185,15 +1185,36 @@ defmodule Nx.DefnTest do
end

test "upcasts float literals based on the accumulated clause type" do
for input_type <- [f: 32, f: 64, c: 64, c: 128] do
for input_type <- [f: 32, f: 64] do
assert %T{
type: ^input_type,
data: %Expr{op: :cond, args: [[clause1, clause2], _last]}
} =
cond_upcast_float_literals(Nx.tensor(10.0, type: input_type))

assert {_, %T{type: ^input_type, data: %Expr{op: :constant, args: [1.4]}}} = clause1
assert {_, %T{type: {:s, 32}, data: %Expr{op: :constant, args: [2]}}} = clause2
assert {_, %T{type: ^input_type, data: %Expr{op: :constant, args: [2.0]}}} = clause2
end

for input_type <- [c: 64, c: 128] do
assert %T{
type: ^input_type,
data: %Expr{op: :cond, args: [[clause1, clause2], _last]}
} =
cond_upcast_float_literals(Nx.tensor(10.0, type: input_type))

assert {_,
%T{
type: ^input_type,
data: %Expr{op: :constant, args: [%Complex{re: 1.4, im: +0.0}]}
}} = clause1

assert {_,
%T{
type: ^input_type,
data: %Expr{op: :constant, args: [%Complex{re: 2.0, im: +0.0}]}
}} =
clause2
end
end

Expand Down

0 comments on commit c31ecb9

Please sign in to comment.