Skip to content

Commit 5826779

Browse files
authored
fix: upcast floats in cond (#1573)
1 parent 32bf7e8 commit 5826779

File tree

3 files changed

+59
-4
lines changed

3 files changed

+59
-4
lines changed

nx/lib/nx.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -943,7 +943,7 @@ defmodule Nx do
943943

944944
for t <-
945945
[:u2, :u4, :u8, :u16, :u32, :u64, :s2, :s4, :s8, :s16, :s32, :s64] ++
946-
[:f8, :bf16, :f16, :f32, :f64] do
946+
[:f8, :bf16, :f16, :f32, :f64, :c64, :c128] do
947947
@doc """
948948
Short-hand function for creating tensor of type `#{t}`.
949949

nx/lib/nx/defn/expr.ex

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -249,9 +249,18 @@ defmodule Nx.Defn.Expr do
249249

250250
result =
251251
for expr <- [last | exprs] do
252-
expr
253-
|> Nx.as_type(type)
254-
|> Nx.broadcast(shape, names: names)
252+
typed_expr =
253+
case expr do
254+
%T{data: %Expr{op: :constant}} ->
255+
expr
256+
|> maybe_upcast_float_constant(type)
257+
|> Nx.as_type(type)
258+
259+
expr ->
260+
Nx.as_type(expr, type)
261+
end
262+
263+
Nx.broadcast(typed_expr, shape, names: names)
255264
end
256265

257266
{result, vectorized_axes}
@@ -1401,6 +1410,10 @@ defmodule Nx.Defn.Expr do
14011410
defp constant(%{shape: shape, type: type} = out, number) do
14021411
number =
14031412
cond do
1413+
Nx.Type.complex?(type) and
1414+
(is_number(number) or number in [:infinity, :neg_infinity, :nan]) ->
1415+
Complex.new(number, 0.0)
1416+
14041417
is_integer(number) and Nx.Type.float?(type) ->
14051418
Complex.multiply(1.0, number)
14061419

nx/test/nx/defn_test.exs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1176,6 +1176,48 @@ defmodule Nx.DefnTest do
11761176
)
11771177
end
11781178

1179+
defn cond_upcast_float_literals(n) do
1180+
cond do
1181+
n == 1 -> 1.4
1182+
n == 2 -> 2
1183+
true -> n
1184+
end
1185+
end
1186+
1187+
test "upcasts float literals based on the accumulated clause type" do
1188+
for input_type <- [f: 32, f: 64] do
1189+
assert %T{
1190+
type: ^input_type,
1191+
data: %Expr{op: :cond, args: [[clause1, clause2], _last]}
1192+
} =
1193+
cond_upcast_float_literals(Nx.tensor(10.0, type: input_type))
1194+
1195+
assert {_, %T{type: ^input_type, data: %Expr{op: :constant, args: [1.4]}}} = clause1
1196+
assert {_, %T{type: ^input_type, data: %Expr{op: :constant, args: [2.0]}}} = clause2
1197+
end
1198+
1199+
for input_type <- [c: 64, c: 128] do
1200+
assert %T{
1201+
type: ^input_type,
1202+
data: %Expr{op: :cond, args: [[clause1, clause2], _last]}
1203+
} =
1204+
cond_upcast_float_literals(Nx.tensor(10.0, type: input_type))
1205+
1206+
assert {_,
1207+
%T{
1208+
type: ^input_type,
1209+
data: %Expr{op: :constant, args: [%Complex{re: 1.4, im: +0.0}]}
1210+
}} = clause1
1211+
1212+
assert {_,
1213+
%T{
1214+
type: ^input_type,
1215+
data: %Expr{op: :constant, args: [%Complex{re: 2.0, im: +0.0}]}
1216+
}} =
1217+
clause2
1218+
end
1219+
end
1220+
11791221
defn cond_list(a) do
11801222
if Nx.any(a), do: 1, else: -1
11811223
end

0 commit comments

Comments
 (0)