diff --git a/src/Ops.jl b/src/Ops.jl index 11c75820a..a873de4ca 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -548,6 +548,18 @@ end # ) # return TracedRArray{T,N}((), res, size(x)) # end +@noinline function bitcast_convert( + ::Type{U}, + x::TracedRNumber{T}; + location=mlir_stacktrace("bitcast_convert", @__FILE__, @__LINE__), +) where {T,U} + res = MLIR.IR.result( + stablehlo.bitcast_convert( + x.mlir_data; result_0=mlir_type(TracedRArray{U,0}, ()), location + ), + ) + return TracedRNumber{U}((), res) +end @noinline function fft( x::TracedRArray{T,N}; diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 75bfc708c..e7410a0b4 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -94,7 +94,6 @@ for (jlop, hloop) in ( (:(Base.:*), :multiply), (:(Base.:/), :divide), (:(Base.:^), :power), - (:(Base.mod), :remainder), (:(Base.rem), :remainder), ) @eval function $(jlop)( @@ -109,13 +108,30 @@ function Base.rem( ) where {T} return Ops.remainder(lhs, TracedUtils.promote_to(TracedRNumber{T}, rhs)) end - function Base.rem( @nospecialize(lhs::Number), @nospecialize(rhs::TracedRNumber{T}) ) where {T} return Ops.remainder(TracedUtils.promote_to(TracedRNumber{T}, lhs), rhs) end +# Based on https://github.com/JuliaLang/julia/blob/39255d47db7657950ff1c82137ecec5a70bae622/base/float.jl#L608-L617 +function Base.mod( + @nospecialize(x::Reactant.TracedRNumber{T}), @nospecialize(y::Reactant.TracedRNumber{T}) +) where {T} + r = rem(x, y) + return ifelse(r == 0, copysign(r, y), ifelse((r > 0) ⊻ (y > 0), r + y, r)) +end +function Base.mod( + @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::Number) +) where {T} + return mod(lhs, TracedUtils.promote_to(TracedRNumber{T}, rhs)) +end +function Base.mod( + @nospecialize(lhs::Number), @nospecialize(rhs::TracedRNumber{T}) +) where {T} + return mod(TracedUtils.promote_to(TracedRNumber{T}, lhs), rhs) +end + function Base.div(@nospecialize(lhs::TracedRNumber{T}), rhs) where {T<:Integer} return Ops.divide(lhs, TracedUtils.promote_to(TracedRNumber{T}, rhs)) end @@ -224,6 +240,12 @@ for (T1, T2) in zip((Bool, Integer), (Bool, Integer)) TracedUtils.promote_to(TracedRNumber{$(T)}, y), ) end + function Base.xor(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)}) + return Ops.xor( + TracedUtils.promote_to(TracedRNumber{$(T)}, x), + TracedUtils.promote_to(TracedRNumber{$(T)}, y), + ) + end Base.:!(x::TracedRNumber{<:$(T1)}) = Ops.not(x) end end @@ -391,4 +413,20 @@ function Base.typed_hvncat( return Base.typed_hvncat(T, dims, row_first, xs...) end +for (Ti, Tf) in ((Int16, Float16), (Int32, Float32), (Int64, Float64)) + @eval begin + Base.signbit(x::TracedRNumber{$(Ti)}) = x < 0 + Base.signbit(x::TracedRNumber{$(Tf)}) = signbit(Ops.bitcast_convert($(Ti), x)) + end +end +Base.signbit(::TracedRNumber{<:Unsigned}) = ConcreteRNumber(false) + +Base.copysign(x::TracedRNumber, y::TracedRNumber) = ifelse(signbit(y), -1, 1) * abs(x) +function Base.copysign(x::TracedRNumber{T}, y::S) where {T,S<:Number} + return copysign(x, TracedUtils.promote_to(TracedRNumber{S}, y)) end +function Base.copysign(x::S, y::TracedRNumber{T}) where {S<:Number,T} + return copysign(TracedUtils.promote_to(TracedRNumber{S}, x), y) +end + +end # module TracedRNumberOverrides diff --git a/test/basic.jl b/test/basic.jl index 08c6bff81..99a19ceb0 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -829,13 +829,10 @@ end a = [-1.1, 7.7, -3.3, 9.9, -5.5] b = [6.6, -2.2, -8.8, 4.4, -10.1] - # Currently broken because `mod` is JIT-ed to an HLO operator with same semantic as - # Julia's `rem`, rather than `mod`. expected_mod = mod.(a, b) - @test_broken Reactant.@jit(mod.(Reactant.to_rarray(a), Reactant.to_rarray(b))) ≈ - expected_mod - @test_broken Reactant.@jit(mod.(a, Reactant.to_rarray(b))) ≈ expected_mod - @test_broken Reactant.@jit(mod.(Reactant.to_rarray(a), b)) ≈ expected_mod + @test Reactant.@jit(mod.(Reactant.to_rarray(a), Reactant.to_rarray(b))) ≈ expected_mod + @test Reactant.@jit(mod.(a, Reactant.to_rarray(b))) ≈ expected_mod + @test Reactant.@jit(mod.(Reactant.to_rarray(a), b)) ≈ expected_mod expected_rem = rem.(a, b) @test Reactant.@jit(rem.(Reactant.to_rarray(a), Reactant.to_rarray(b))) ≈ expected_rem @@ -843,6 +840,26 @@ end @test Reactant.@jit(rem.(Reactant.to_rarray(a), b)) ≈ expected_rem end +@testset "xor" begin + for a in (true, false), b in (true, false) + @test @jit(xor(ConcreteRNumber(a), ConcreteRNumber(b))) == xor(a, b) + end +end + +@testset "signbit" begin + for x in (-4, -3.14, -0.0f0, 0.0, 0, 5, 6.28f0) + @test @jit(signbit(ConcreteRNumber(x))) == signbit(x) + end +end + +@testset "copysign" begin + for a in (-3.14, -2, 0.0, 2.71, 42), b in (-7, -0.57, -0.0, 1, 3.14) + # Make sure also the return type is correct + @test Reactant.to_number(@jit(copysign(ConcreteRNumber(a), ConcreteRNumber(b)))) === + copysign(a, b) + end +end + @testset "reduce integers" begin x = rand(Bool, 100) x_ra = Reactant.to_rarray(x)