diff --git a/src/Reactant.jl b/src/Reactant.jl index 10155b75c..81a57eb2c 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -65,8 +65,8 @@ include("XLA.jl") include("Interpreter.jl") include("utils.jl") include("ConcreteRArray.jl") -include("TracedRArray.jl") include("TracedRNumber.jl") +include("TracedRArray.jl") include("Tracing.jl") include("Compiler.jl") diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index f68dd51c2..4b3dd861f 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -213,10 +213,10 @@ end promote_to(::TracedRArray{T,N}, rhs) where {T,N} = promote_to(TracedRArray{T,N}, rhs) -struct TypeCast{T<:Number} <: Function end - -elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:Number} = x -function elem_apply(::Type{T}, x::TracedRArray{T2}) where {T<:Number,T2<:Number} +elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:ReactantPrimitives} = x +function elem_apply( + ::Type{T}, x::TracedRArray{T2} +) where {T<:ReactantPrimitives,T2<:ReactantPrimitives} # Special Path to prevent going down a despecialized path return elem_apply(TypeCast{T}(), x) end diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index e0f44ba84..a0ff140c6 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -158,6 +158,8 @@ for (jlop, hloop) in ( end end +struct TypeCast{T<:ReactantPrimitives} <: Function end + (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} = promote_to(TracedRNumber{T}, x) Base.float(x::TracedRNumber{T}) where {T} = promote_to(TracedRNumber{float(T)}, x)