diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl
index bac3ee75ff..f7ee82c204 100644
--- a/ext/ReactantNNlibExt.jl
+++ b/ext/ReactantNNlibExt.jl
@@ -1,31 +1,24 @@
 module ReactantNNlibExt
 
 using NNlib
-using Reactant: Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR
+using Reactant:
+    Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber
 
 for (jlop, hloop) in (
     (:(NNlib.tanh_fast), :tanh),
     (:(NNlib.sigmoid_fast), :logistic),
     (:(NNlib.sigmoid), :logistic),
 )
-    @eval function $(jlop)(x::TracedRArray{T,0}) where {T}
-        return TracedRArray{T,0}(
+    @eval function $(jlop)(x::TracedRNumber{T}) where {T}
+        return TracedRNumber{T}(
             (),
             Reactant.MLIR.IR.result(
                 Reactant.MLIR.Dialects.stablehlo.$(hloop)(x.mlir_data), 1
             ),
-            (),
         )
     end
 end
 
-# Don't confuse our poor scalar arrays, we no like numbers we like 0D arrays
-for nnlib_op in setdiff(Tuple(NNlib.ACTIVATIONS), (:tanh_fast, :sigmoid_fast, :sigmoid, :σ))
-    @eval function NNlib.$(nnlib_op)(x::TracedRArray{T,0}) where {T}
-        return invoke(NNlib.$(nnlib_op), Tuple{Any}, x)
-    end
-end
-
 # TODO handle non finite cases
 function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N}
     max_ = NNlib.fast_maximum(x; dims)
@@ -39,6 +32,20 @@ function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where
     return out ./= tmp
 end
 
+function NNlib.logsoftmax!(out::TracedRArray{T}, x::AbstractArray; dims=1) where {T}
+    max_ = NNlib.fast_maximum(x; dims)
+    # if all(isfinite, max_)
+    @fastmath out .= x .- max_
+    # else
+    #     _zero, _minf, _inf = T(0), T(-Inf), T(Inf)
+    #     @. out = ifelse(
+    #         isequal(max_, _inf), ifelse(isequal(x, _inf), _zero, _minf), x - max_
+    #     )
+    # end
+    @fastmath log_ = log.(sum(exp, out; dims))
+    return out .-= log_
+end
+
 function NNlib.conv(
     x::AnyTracedRArray{T,N}, W::AnyTracedRArray{T}, cdims::DenseConvDims
 ) where {T,N}
diff --git a/src/Compiler.jl b/src/Compiler.jl
index 94a4f35a35..4b07ab2b10 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -6,10 +6,12 @@ import ..Reactant:
     XLA,
     ConcreteRArray,
     TracedRArray,
+    TracedRNumber,
     OrderedIdDict,
     make_tracer,
     TracedToConcrete,
-    append_path
+    append_path,
+    TracedType
 
 @inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field)
 
@@ -286,10 +288,10 @@ function compile_mlir!(mod, f, args; optimize=true)
         )
     end
 
-    preserved_args = Tuple{TracedRArray,Int}[]
+    preserved_args = Tuple{TracedType,Int}[]
     results = [MLIR.IR.operand(ret, i) for i in 1:MLIR.IR.noperands(ret)]
     nresults = MLIR.IR.Value[]
-    linear_results2 = TracedRArray[]
+    linear_results2 = TracedType[]
     for (i, op) in enumerate(results)
         if !MLIR.IR.is_block_arg(op)
            push!(nresults, op)
@@ -573,7 +575,7 @@ end
 function compile_xla(f, args; client=nothing)
     # register MLIR dialects
     ctx = MLIR.IR.Context()
-    Base.append!(Reactant.registry[]; context=ctx)
+    append!(Reactant.registry[]; context=ctx)
     @ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
 
     return MLIR.IR.context!(ctx) do
diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl
index 8f70f324d3..9c2077b27b 100644
--- a/src/ConcreteRArray.jl
+++ b/src/ConcreteRArray.jl
@@ -74,10 +74,6 @@ function Base.convert(::Type{T}, x::ConcreteRArray{T,0}) where {T}
     return to_float(x)
 end
 
-function Base.promote_rule(::Type{<:RArray{T1,0}}, ::Type{T2}) where {T1,T2}
-    return Base.promote_rule(T1, T2)
-end
-
 for jlop in (:(Base.isless), :(Base.:+), :(Base.:-), :(Base.:*), :(Base.:/), :(Base.:^))
     @eval begin
         function $jlop(x::ConcreteRArray{T,0}, y::ConcreteRArray{U,0}) where {T,U}
@@ -158,7 +154,7 @@ function Base.getindex(a::ConcreteRArray{T}, args::Vararg{Int,N}) where {T,N}
 end
 
 function mysetindex!(a, v, args::Vararg{Int,N}) where {N}
-    Base.setindex!(a, v, args...)
+    setindex!(a, v, args...)
     return nothing
 end
 
diff --git a/src/Reactant.jl b/src/Reactant.jl
index 92b4c9cdfd..82bf06ea60 100644
--- a/src/Reactant.jl
+++ b/src/Reactant.jl
@@ -7,7 +7,45 @@ include("OrderedIdDict.jl")
 
 using Enzyme
 
-abstract type RArray{T,N} <: AbstractArray{T,N} end
+@static if isdefined(Core, :BFloat16)
+    const ReactantPrimitive = Union{
+        Bool,
+        Int8,
+        UInt8,
+        Int16,
+        UInt16,
+        Int32,
+        UInt32,
+        Int64,
+        UInt64,
+        Float16,
+        Core.BFloat16,
+        Float32,
+        Float64,
+        Complex{Float32},
+        Complex{Float64},
+    }
+else
+    const ReactantPrimitive = Union{
+        Bool,
+        Int8,
+        UInt8,
+        Int16,
+        UInt16,
+        Int32,
+        UInt32,
+        Int64,
+        UInt64,
+        Float16,
+        Float32,
+        Float64,
+        Complex{Float32},
+        Complex{Float64},
+    }
+end
+
+abstract type RArray{T<:ReactantPrimitive,N} <: AbstractArray{T,N} end
+abstract type RNumber{T<:ReactantPrimitive} <: Number end
 
 function Base.reshape(A::RArray, dims::Tuple{Vararg{Union{Int,Colon}}})
     return reshape(A, Base._reshape_uncolon(A, dims))
@@ -45,8 +83,13 @@ include("mlir/MLIR.jl")
 include("XLA.jl")
 include("Interpreter.jl")
 include("utils.jl")
+
 include("ConcreteRArray.jl")
+include("TracedRNumber.jl")
 include("TracedRArray.jl")
+
+const TracedType = Union{TracedRArray,TracedRNumber}
+
 include("Tracing.jl")
 include("Compiler.jl")
 
diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl
index 2780179098..bafea85a9f 100644
--- a/src/TracedRArray.jl
+++ b/src/TracedRArray.jl
@@ -21,7 +21,6 @@ TracedRArray{T,N}(x::TracedRArray{T,N}) where {T,N} = x
 
 const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}
 const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
-const AnyTracedRScalar{T} = AnyTracedRArray{T,0}
 const AnyTracedRVector{T} = AnyTracedRArray{T,1}
 const AnyTracedRMatrix{T} = AnyTracedRArray{T,2}
 const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}}
@@ -40,15 +39,6 @@ function get_ancestor_indices(x::WrappedTracedRArray, indices...)
     return get_ancestor_indices(parent(x), Base.reindex(parentindices(x), indices)...)
 end
 
-Base.getindex(a::AnyTracedRScalar{T}) where {T} = a
-
-Base.zero(::AnyTracedRScalar{T}) where {T} = promote_to(TracedRArray{T,0}, zero(T))
-Base.one(::AnyTracedRScalar{T}) where {T} = promote_to(TracedRArray{T,0}, one(T))
-
-function Base.convert(::Type{<:AnyTracedRScalar{T}}, x::Number) where {T}
-    return promote_to(TracedRArray{T,0}, T(x))
-end
-
 function Base.getindex(a::TracedRArray{T,N}, index::Vararg{Int,N}) where {T,N}
     @warn(
         """Performing scalar indexing on task $(current_task()).
@@ -73,7 +63,11 @@ and require expensive copies and synchronization each time and therefore should
         ),
         1,
     )
-    return TracedRArray{T,0}((), res2, ())
+    return TracedRNumber{T}((), res2)
+end
+
+function Base.getindex(a::TracedRArray{T,0}) where {T}
+    return TracedRNumber{T}((), a.mlir_data)
 end
 
 function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
@@ -105,13 +99,17 @@ function Base.getindex(a::WrappedTracedRArray, indices...)
 end
 
 function Base.setindex!(
-    a::TracedRArray{T,N}, v, indices::Vararg{Union{Base.AbstractUnitRange,Colon},N}
+    a::TracedRArray{T,N}, v, indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int},N}
 ) where {T,N}
+    indices = map(enumerate(indices)) do (idx, i)
+        i isa Int ? (i:i) : (i isa Colon ? (1:size(a, idx)) : i)
+    end
+    v = broadcast_to_size(v, length.(indices))
+    v = promote_to(TracedRArray{T,N}, v)
     indices = [
-        (promote_to(TracedRArray{Int,0}, i isa Colon ? 1 : first(i)) - 1).mlir_data for
+        (promote_to(TracedRNumber{Int}, i isa Colon ? 1 : first(i)) - 1).mlir_data for
         i in indices
     ]
-    v = promote_to(TracedRArray{T,N}, v)
     res = MLIR.IR.result(
         MLIR.Dialects.stablehlo.dynamic_update_slice(a.mlir_data, v.mlir_data, indices), 1
     )
@@ -133,8 +131,6 @@ function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOC
     # return print(io, X.mlir_data, ")")
 end
 
-Base.only(A::AnyTracedRScalar{T}) where {T} = A
-
 function Base.reshape(A::AnyTracedRArray{T,N}, dims::NTuple{NT,Int}) where {T,N,NT}
     if prod(dims) != prod(size(A))
         throw(
@@ -195,21 +191,9 @@ function Base.transpose(A::AnyTracedRVecOrMat)
 end
 Base.adjoint(A::AnyTracedRVecOrMat{<:Real}) = transpose(A)
 
-function Base.promote_rule(
-    ::Type{TracedRArray{T,N}}, ::Type{TracedRArray{S,N}}
-) where {T,S,N}
-    return TracedRArray{Base.promote_type(T, S),N}
-end
-
-function Base.promote_rule(::Type{T}, ::Type{TracedRArray{S,N}}) where {T,S,N}
-    return TracedRArray{Base.promote_type(T, S),N}
-end
-
 function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
     if isa(rhs, TracedRArray)
-        if typeof(rhs) == TracedRArray{T,N}
-            return rhs
-        end
+        rhs isa TracedRArray{T,N} && return rhs
         return TracedRArray{T,N}(
             (),
             MLIR.IR.result(
@@ -222,11 +206,8 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
         )
     end
     if isa(rhs, Number)
-        attr = fill(MLIR.IR.Attribute(T(rhs)), mlir_type(TracedRArray{T,N}, size(rhs)))
-        ta = TracedRArray{T,N}(
-            (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1), size(rhs)
-        )
-        return ta
+        throw(ArgumentError("Cannot promote number to `TracedRArray`. Use \
+                             `TracedRNumber` instead."))
     end
     T0 = eltype(rhs)
     attr = MLIR.IR.DenseElementsAttribute(collect(rhs))
@@ -238,120 +219,23 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
     )
 end
 
-function promote_to(::TracedRArray{T,N}, rhs) where {T,N}
-    return promote_to(TracedRArray{T,N}, rhs)
-end
-
-for (jlop, hloop) in (
-    (:(Base.min), :minimum),
-    (:(Base.max), :maximum),
-    (:(Base.:+), :add),
-    (:(Base.:-), :subtract),
-    (:(Base.:*), :multiply),
-    (:(Base.:/), :divide),
-    (:(Base.:^), :power),
-)
-    @eval begin
-        function $(jlop)(
-            @nospecialize(lhs::TracedRArray{T,0}), @nospecialize(rhs::TracedRArray{T,0})
-        ) where {T}
-            return TracedRArray{T,0}(
-                (),
-                MLIR.IR.result(
-                    MLIR.Dialects.stablehlo.$(hloop)(lhs.mlir_data, rhs.mlir_data), 1
-                ),
-                (),
-            )
-        end
-
-        function $(jlop)(
-            @nospecialize(lhs::TracedRArray{T1,0}), @nospecialize(rhs::TracedRArray{T2,0})
-        ) where {T1,T2}
-            commonTy = TracedRArray{Base.promote_type(T1, T2),0}
-            lhs = promote_to(commonTy, lhs)
-            rhs = promote_to(commonTy, rhs)
-            return $(jlop)(lhs, rhs)
-        end
-    end
-
-    for otherType in (Number, Any)
-        @eval begin
-            function $(jlop)(
-                @nospecialize(lhs::TracedRArray{T,0}), @nospecialize(rhs::$(otherType))
-            ) where {T}
-                rhs = promote_to(lhs, rhs)
-                return $(jlop)(lhs, rhs)
-            end
-
-            function $(jlop)(
-                @nospecialize(lhs::$(otherType)), @nospecialize(rhs::TracedRArray{T,0})
-            ) where {T}
-                lhs = promote_to(rhs, lhs)
-                return $(jlop)(lhs, rhs)
-            end
-        end
-    end
-end
-
-function Base.ifelse(
-    @nospecialize(pred::TracedRArray{Bool,0}),
-    @nospecialize(x::TracedRArray{T1,0}),
-    @nospecialize(y::TracedRArray{T2,0})
-) where {T1,T2}
-    return TracedRArray{promote_type(T1, T2),0}(
-        (),
-        MLIR.IR.result(
-            MLIR.Dialects.stablehlo.select(pred.mlir_data, x.mlir_data, y.mlir_data), 1
-        ),
-        size(pred),
-    )
-end
-
-Base.abs2(x::Reactant.TracedRArray{T,0}) where {T} = x * conj(x)
+promote_to(::TracedRArray{T,N}, rhs) where {T,N} = promote_to(TracedRArray{T,N}, rhs)
 
-function Base.literal_pow(
-    ::Base.RefValue{typeof(^)}, x::TracedRArray{T,0}, ::Base.RefValue{Val{P}}
-) where {T,P}
-    return Base.literal_pow(^, x, Val(P))
-end
-
-for (jlop, hloop) in (
-    (:(Base.abs), :abs),
-    (:(Base.:-), :negate),
-    (:(Base.sin), :sine),
-    (:(Base.cos), :cosine),
-    (:(Base.tanh), :tanh),
-    (:(Base.FastMath.tanh_fast), :tanh),
-    (:(Base.exp), :exponential),
-    (:(Base.FastMath.exp_fast), :exponential),
-    (:(Base.log), :log),
-    (:(Base.sqrt), :sqrt),
-)
-    @eval begin
-        function $jlop(@nospecialize(lhs::TracedRArray{T,0})) where {T}
-            return TracedRArray{T,0}(
-                (),
-                MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1),
-                size(lhs),
-            )
-        end
-    end
-end
-
-struct TypeCast{T<:Number} <: Function end
-
-function (::TypeCast{T})(x::TracedRArray{T2,0}) where {T,T2}
-    return promote_to(TracedRArray{T,0}, x)
-end
-
-elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:Number} = x
-function elem_apply(::Type{T}, x::TracedRArray{T2}) where {T<:Number,T2<:Number}
+elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:ReactantPrimitive} = x
+function elem_apply(
+    ::Type{T}, x::TracedRArray{T2}
+) where {T<:ReactantPrimitive,T2<:ReactantPrimitive}
     # Special Path to prevent going down a despecialized path
     return elem_apply(TypeCast{T}(), x)
 end
 
 function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
-    all(iszero ∘ ndims, args) && return f(args...)
+    if all(iszero ∘ ndims, args)
+        scalar_args = map(args) do arg
+            return promote_to(TracedRNumber{eltype(arg)}, arg)
+        end
+        return f(scalar_args...)
+    end
     fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn(
         f, args, (), string(f) * "_broadcast_scalar", false; toscalar=true
     )
@@ -362,7 +246,8 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
         invmap[v] = k
     end
 
-    input_shapes = size.(keys(seen_args))
+    keys_seen = [k for k in keys(seen_args) if k isa TracedType]
+    input_shapes = size.(keys_seen)
     # by the time we reach here all args must have same size
     @assert allequal(input_shapes) "input shapes are $(input_shapes)"
     OutShape = isempty(seen_args) ? nothing : first(input_shapes)
@@ -428,56 +313,13 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
     return traced2_result
 end
 
-for (jlop, hloop, hlocomp, merge) in (
-    (:(Base.:(==)), :compare, "EQ", :all),
-    (:(Base.:(!=)), :compare, "NE", :any),
-    (:(Base.:(>=)), :compare, "GE", nothing),
-    (:(Base.:(>)), :compare, "GT", nothing),
-    (:(Base.:(<=)), :compare, "LE", nothing),
-    (:(Base.:(<)), :compare, "LT", nothing),
-)
-    @eval begin
-        function $(jlop)(
-            @nospecialize(lhs::TracedRArray{T,0}), @nospecialize(rhs::TracedRArray{T,0})
-        ) where {T}
-            return TracedRArray{Bool,0}(
-                (),
-                MLIR.IR.result(
-                    MLIR.Dialects.stablehlo.$hloop(
-                        lhs.mlir_data,
-                        rhs.mlir_data;
-                        comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet(
-                            MLIR.IR.context(), $hlocomp
-                        ),
-                    ),
-                    1,
-                ),
-                size(lhs),
-            )
-        end
-
-        function $(jlop)(
-            @nospecialize(lhs::TracedRArray{T,0}), @nospecialize(rhs)
-        ) where {T}
-            return $(jlop)(lhs, promote_to(lhs, rhs))
-        end
-
-        function $(jlop)(
-            @nospecialize(lhs), @nospecialize(rhs::TracedRArray{T,0})
-        ) where {T}
-            return $(jlop)(promote_to(rhs, lhs), rhs)
-        end
-    end
-
-    if merge !== nothing
-        @eval begin
-            function $jlop(
-                @nospecialize(lhs::TracedRArray{T,N}), @nospecialize(rhs::TracedRArray{T,N})
-            ) where {T,N}
-                elems = $(jlop).(lhs, rhs)
-                return N == 0 ? elems : $(merge)(elems)
-            end
-        end
+for (jlop, hloop, hlocomp, merge) in
+    ((:(Base.:(==)), :compare, "EQ", :all), (:(Base.:(!=)), :compare, "NE", :any))
+    @eval function $jlop(
+        @nospecialize(lhs::TracedRArray{T,N}), @nospecialize(rhs::TracedRArray{T,N})
+    ) where {T,N}
+        elems = $(jlop).(lhs, rhs)
+        return N == 0 ? elems : $(merge)(elems)
     end
 end
 
@@ -556,8 +398,7 @@ function Base.mapreduce(
     fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in in_tys])
 
     args = (
-        TracedRArray{T,0}((), MLIR.IR.argument(fnbody, i), ()) for
-        (i, ty) in enumerate(in_tys)
+        TracedRNumber{T}((), MLIR.IR.argument(fnbody, i)) for (i, ty) in enumerate(in_tys)
     )
 
     res = MLIR.IR.block!(fnbody) do
@@ -591,7 +432,11 @@ function Base.mapreduce(
         )
         red = TracedRArray{T,length(toonedims)}((), red, (toonedims...,))
     else
-        red = TracedRArray{T,length(outdims)}((), red, (outdims...,))
+        if length(outdims) == 0
+            red = TracedRNumber{T}((), red)
+        else
+            red = TracedRArray{T,length(outdims)}((), red, (outdims...,))
+        end
     end
     return red
 end
@@ -613,25 +458,31 @@ function Base.fill!(A::TracedRArray{T,N}, x) where {T,N}
     return A
 end
 
+function Base.fill!(A::TracedRArray{T,N}, x::TracedRNumber{T2}) where {T,N,T2}
+    bcast = broadcast_to_size(promote_to(TracedRNumber{T}, x), size(A))
+    A.mlir_data = bcast.mlir_data
+    return A
+end
+
 struct AbstractReactantArrayStyle{N} <: Base.Broadcast.AbstractArrayStyle{N} end
 AbstractReactantArrayStyle(::Val{N}) where {N} = AbstractReactantArrayStyle{N}()
 AbstractReactantArrayStyle{M}(::Val{N}) where {N,M} = AbstractReactantArrayStyle{N}()
 
-# function Broadcast.materialize(bc::Broadcasted)
-#    @show bc
-#    inst = instantiate(bc)
-#    @show inst
-#    copy(inst)
-# end
-
 function BroadcastStyle(::Type{<:AnyTracedRArray{T,N}}) where {T,N}
     return AbstractReactantArrayStyle{N}()
 end
 
 function Base.similar(
     bc::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, dims
-) where {T,N}
+) where {T<:ReactantPrimitive,N}
+    @assert N isa Int
+    return TracedRArray{T,N}((), nothing, map(length, dims))
+end
+
+function Base.similar(
+    bc::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{<:TracedRNumber{T}}, dims
+) where {T<:ReactantPrimitive,N}
     @assert N isa Int
     return TracedRArray{T,N}((), nothing, map(length, dims))
 end
@@ -708,6 +559,13 @@ function broadcast_to_size(arg::T, rsize) where {T<:Number}
     )
 end
 
+function broadcast_to_size(arg::TracedRNumber, rsize)
+    length(rsize) == 0 && return arg
+    return broadcast_to_size_internal(
+        TracedRArray{eltype(arg),0}((), arg.mlir_data, ()), rsize
+    )
+end
+
 function broadcast_to_size(arg::AnyTracedRArray, rsize)
     arg = materialize_traced_array(arg)
     size(arg) == rsize && return arg
diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl
new file mode 100644
index 0000000000..b41dfcb481
--- /dev/null
+++ b/src/TracedRNumber.jl
@@ -0,0 +1,242 @@
+mutable struct TracedRNumber{T} <: RNumber{T}
+    paths::Tuple
+    mlir_data::Union{Nothing,MLIR.IR.Value}
+
+    function TracedRNumber{T}(
+        paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value}
+    ) where {T}
+        if !isnothing(mlir_data)
+            @assert size(MLIR.IR.type(mlir_data)) == ()
+        end
+        return new{T}(paths, mlir_data)
+    end
+end
+
+Base.eltype(::Type{TracedRNumber{T}}) where {T} = T
+
+Base.getindex(a::TracedRNumber{T}) where {T} = a
+
+Base.zero(::TracedRNumber{T}) where {T} = promote_to(TracedRNumber{T}, zero(T))
+Base.one(::TracedRNumber{T}) where {T} = promote_to(TracedRNumber{T}, one(T))
+
+Base.eps(::Type{TracedRNumber{T}}) where {T} = promote_to(TracedRNumber{T}, eps(T))
+
+function Base.convert(::Type{<:TracedRNumber{T}}, x::Number) where {T}
+    return promote_to(TracedRNumber{T}, T(x))
+end
+
+function Base.show(io::IOty, X::TracedRNumber{T}) where {T,IOty<:Union{IO,IOContext}}
+    return print(io, "TracedRNumber{", T, "}(", X.paths, ")")
+end
+
+Base.only(A::TracedRNumber{T}) where {T} = A
+
+function Base.promote_rule(::Type{TracedRNumber{T}}, ::Type{TracedRNumber{S}}) where {T,S}
+    return TracedRNumber{Base.promote_type(T, S)}
+end
+
+# Bool has special promotion rules in Base
+function Base.promote_rule(::Type{Bool}, ::Type{TracedRNumber{T}}) where {T}
+    return TracedRNumber{T}
+end
+
+function Base.promote_rule(::Type{TracedRNumber{T}}, ::Type{Bool}) where {T}
+    return TracedRNumber{T}
+end
+
+function Base.promote_rule(::Type{T}, ::Type{TracedRNumber{S}}) where {T,S}
+    return TracedRNumber{Base.promote_type(T, S)}
+end
+
+function Base.convert(::Type{TracedRNumber{T}}, x::Number) where {T}
+    return promote_to(TracedRNumber{T}, x)
+end
+
+function promote_to(::Type{TracedRNumber{T}}, rhs) where {T}
+    if isa(rhs, TracedRNumber)
+        rhs isa TracedRNumber{T} && return rhs
+        return TracedRNumber{T}(
+            (),
+            MLIR.IR.result(
+                MLIR.Dialects.stablehlo.convert(
+                    rhs.mlir_data; result=mlir_type(TracedRNumber{T})
+                ),
+                1,
+            ),
+        )
+    end
+    if isa(rhs, TracedRArray{<:Any,0})
+        return TracedRNumber{T}(
+            (),
+            MLIR.IR.result(
+                MLIR.Dialects.stablehlo.convert(
+                    rhs.mlir_data; result=mlir_type(TracedRNumber{T})
+                ),
+                1,
+            ),
+        )
+    end
+    if isa(rhs, Number)
+        attr = fill(MLIR.IR.Attribute(T(rhs)), mlir_type(TracedRNumber{T}))
+        return TracedRNumber{T}(
+            (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1)
+        )
+    end
+    T0 = eltype(rhs)
+    attr = MLIR.IR.DenseElementsAttribute(collect(rhs))
+    return promote_to(
+        TracedRNumber{T},
+        TracedRNumber{T0}(
+            (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1)
+        ),
+    )
+end
+
+promote_to(::TracedRNumber{T}, rhs) where {T} = promote_to(TracedRNumber{T}, rhs)
+
+for (jlop, hloop) in (
+    (:(Base.min), :minimum),
+    (:(Base.max), :maximum),
+    (:(Base.:+), :add),
+    (:(Base.:-), :subtract),
+    (:(Base.:*), :multiply),
+    (:(Base.:/), :divide),
+    (:(Base.:^), :power),
+)
+    @eval function $(jlop)(
+        @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T})
+    ) where {T}
+        return TracedRNumber{T}(
+            (),
+            MLIR.IR.result(
+                MLIR.Dialects.stablehlo.$(hloop)(lhs.mlir_data, rhs.mlir_data), 1
+            ),
+        )
+    end
+end
+
+for (jlop, hloop, hlocomp) in (
+    (:(Base.:(==)), :compare, "EQ"),
+    (:(Base.:(!=)), :compare, "NE"),
+    (:(Base.:(>=)), :compare, "GE"),
+    (:(Base.:(>)), :compare, "GT"),
+    (:(Base.:(<=)), :compare, "LE"),
+    (:(Base.:(<)), :compare, "LT"),
+)
+    @eval begin
+        function $(jlop)(
+            @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T})
+        ) where {T}
+            return TracedRNumber{Bool}(
+                (),
+                MLIR.IR.result(
+                    MLIR.Dialects.stablehlo.$(hloop)(
+                        lhs.mlir_data,
+                        rhs.mlir_data;
+                        comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet(
+                            MLIR.IR.context(), $hlocomp
+                        ),
+                    ),
+                    1,
+                ),
+            )
+        end
+
+        function $(jlop)(@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs)) where {T}
+            return $(jlop)(lhs, promote_to(lhs, rhs))
+        end
+
+        function $(jlop)(@nospecialize(lhs), @nospecialize(rhs::TracedRNumber{T})) where {T}
+            return $(jlop)(promote_to(rhs, lhs), rhs)
+        end
+
+        function $(jlop)(
+            @nospecialize(lhs::TracedRNumber{T1}), @nospecialize(rhs::TracedRNumber{T2})
+        ) where {T1,T2}
+            commonTy = TracedRNumber{Base.promote_type(T1, T2)}
+            lhs = promote_to(commonTy, lhs)
+            rhs = promote_to(commonTy, rhs)
+            return $(jlop)(lhs, rhs)
+        end
+    end
+end
+
+function Base.ifelse(
+    @nospecialize(pred::TracedRNumber{Bool}),
+    @nospecialize(x::TracedRNumber{T1}),
+    @nospecialize(y::TracedRNumber{T2})
+) where {T1,T2}
+    return TracedRNumber{promote_type(T1, T2)}(
+        (),
+        MLIR.IR.result(
+            MLIR.Dialects.stablehlo.select(pred.mlir_data, x.mlir_data, y.mlir_data), 1
+        ),
+    )
+end
+
+function Base.literal_pow(
+    ::Base.RefValue{typeof(^)}, x::TracedRNumber{T}, ::Base.RefValue{Val{P}}
+) where {T,P}
+    return Base.literal_pow(^, x, Val(P))
+end
+
+for (jlop, hloop) in (
+    (:(Base.abs), :abs),
+    (:(Base.:-), :negate),
+    (:(Base.sin), :sine),
+    (:(Base.cos), :cosine),
+    (:(Base.tanh), :tanh),
+    (:(Base.FastMath.tanh_fast), :tanh),
+    (:(Base.exp), :exponential),
+    (:(Base.FastMath.exp_fast), :exponential),
+    (:(Base.log), :log),
+    (:(Base.sqrt), :sqrt),
+)
+    @eval function $(jlop)(@nospecialize(lhs::TracedRNumber{T})) where {T}
+        return TracedRNumber{T}(
+            (), MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1)
+        )
+    end
+end
+
+# XXX: Enzyme-MLIR doesn't have `abs` adjoint defined
+Base.abs2(x::TracedRNumber{<:Real}) = x^2
+
+Base.log1p(x::TracedRNumber{T}) where {T} = log(x + one(T))
+
+struct TypeCast{T<:ReactantPrimitive} <: Function end
+
+(::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} = promote_to(TracedRNumber{T}, x)
+
+Base.float(x::TracedRNumber{T}) where {T} = promote_to(TracedRNumber{float(T)}, x)
+
+# Concatenation. Numbers in Julia are handled in a much less generic fashion than arrays
+Base.vcat(x::TracedRNumber...) = Base.typed_vcat(Base.promote_eltypeof(x...), x...)
+function Base.typed_vcat(::Type{T}, x::TracedRNumber...) where {T}
+    return Base.typed_vcat(T, map(Base.Fix2(broadcast_to_size, (1,)), x)...)
+end
+
+Base.hcat(x::TracedRNumber...) = Base.typed_hcat(Base.promote_eltypeof(x...), x...)
+function Base.typed_hcat(::Type{T}, x::TracedRNumber...) where {T}
+    return Base.typed_hcat(T, map(Base.Fix2(broadcast_to_size, (1, 1)), x)...)
+end
+
+function Base.hvcat(rows::Tuple{Vararg{Int}}, xs::TracedRNumber...)
+    return Base.typed_hvcat(Base.promote_eltypeof(xs...), rows, xs...)
+end
+function Base.typed_hvcat(
+    ::Type{T}, rows::Tuple{Vararg{Int}}, xs::TracedRNumber...
+) where {T}
+    xs = map(Base.Fix2(broadcast_to_size, (1, 1)), xs)
+    return Base.typed_hvcat(T, rows, xs...)
+end
+
+function Base.hvncat(dims::Tuple{Vararg{Int}}, row_first::Bool, xs::TracedRNumber...)
+    return Base.typed_hvncat(Base.promote_eltypeof(xs...), dims, row_first, xs...)
+end
+function Base.typed_hvncat(
+    ::Type{T}, dims::Tuple{Vararg{Int}}, row_first::Bool, xs::TracedRNumber...
+) where {T}
+    xs = map(Base.Fix2(broadcast_to_size, (1, 1)), xs)
+    return Base.typed_hvncat(T, dims, row_first, xs...)
+end
diff --git a/src/Tracing.jl b/src/Tracing.jl
index b3233a6447..dd6c4e77cf 100644
--- a/src/Tracing.jl
+++ b/src/Tracing.jl
@@ -16,6 +16,7 @@ for T in (
     Integer,
     AbstractString,
     RArray,
+    RNumber,
 )
     @eval function traced_type(::Type{T}, seen, mode) where {T<:$T}
         return T
@@ -182,7 +183,7 @@ function traced_type(::Type{T}, seen, ::Val{mode}) where {T<:ConcreteRArray,mode
     end
 end
 
-function traced_type(::Type{T}, seen::ST, ::Val{mode}) where {ST,T<:TracedRArray,mode}
+function traced_type(::Type{T}, seen::ST, ::Val{mode}) where {ST,T<:TracedType,mode}
     if mode == ConcreteToTraced
         throw("TracedRArray $T cannot be traced")
     elseif mode == TracedToConcrete
@@ -202,7 +203,7 @@ function traced_type(::Type{T}, seen, mode) where {T<:XLAArray}
 end
 
 function traced_type(::Type{A}, seen::ST, ::Val{mode}) where {T,N,A<:Array{T,N},ST,mode}
-    if mode == ArrayToConcrete && T <: AbstractFloat
+    if mode == ArrayToConcrete && T <: ReactantPrimitive
         return ConcreteRArray{T,N}
     else
         return Array{traced_type(T, seen, Val(mode)),N}
@@ -330,9 +331,9 @@ function make_tracer(
         return seen[prev]
     end
     res = if toscalar
-        TracedRArray{T,0}((path,), nothing, ())
-    elseif !isnothing(tobatch)
-        TracedRArray{T,length(tobatch)}((path,), prev.mlir_data, tobatch)
+        TracedRNumber{T}((path,), nothing)
+    elseif tobatch !== nothing
+        error("This should not happen...")
     else
         TracedRArray{T,N}((path,), prev.mlir_data, size(prev))
     end
@@ -352,6 +353,52 @@ function make_tracer(
     throw("Cannot Unknown trace mode $mode")
 end
 
+function make_tracer(
+    seen,
+    @nospecialize(prev::TracedRNumber{T}),
+    @nospecialize(path),
+    mode;
+    tobatch=nothing,
+    toscalar=false,
+    kwargs...,
+) where {T}
+    if mode == ConcreteToTraced
+        throw("Cannot trace existing trace type")
+    end
+    if mode == TracedTrack
+        prev.paths = (prev.paths..., path)
+        if !haskey(seen, prev)
+            return seen[prev] = prev
+        end
+        return prev
+    end
+    if mode == TracedSetPath
+        if haskey(seen, prev)
+            return seen[prev]
+        end
+        res = if toscalar
+            TracedRNumber{T}((path,), nothing)
+        elseif tobatch !== nothing
+            TracedRArray{T,length(tobatch)}((path,), prev.mlir_data, tobatch)
+        else
+            TracedRNumber{T}((path,), prev.mlir_data)
+        end
+        seen[prev] = res
+        return res
+    end
+
+    if mode == TracedToConcrete
+        if haskey(seen, prev)
+            return seen[prev]::ConcreteRArray{T,0}
+        end
+        res = ConcreteRArray{T,0}(XLA.AsyncEmptyBuffer, size(prev))
+        seen[prev] = res
+        return res
+    end
+
+    throw("Cannot Unknown trace mode $mode")
+end
+
 function make_tracer(
     seen, @nospecialize(prev::RT), @nospecialize(path), mode; kwargs...
 ) where {RT<:AbstractFloat}
@@ -380,7 +427,7 @@ function make_tracer(
     if haskey(seen, prev)
         return seen[prev]
     end
-    if mode == ArrayToConcrete && eltype(RT) <: Union{AbstractFloat,Integer}
+    if mode == ArrayToConcrete && eltype(RT) <: ReactantPrimitive
         return seen[prev] = ConcreteRArray(prev)
     end
     TT = traced_type(eltype(RT), (), Val(mode))
diff --git a/src/XLA.jl b/src/XLA.jl
index 9e77dac6df..556b6ff33a 100644
--- a/src/XLA.jl
+++ b/src/XLA.jl
@@ -227,7 +227,9 @@ end
 
 @inline primitive_type(::Type{Float16}) = 10
 @inline primitive_type(::Type{Float32}) = 11
-# @inline primitive_type(::Type{BFloat16}) = 16
+@static if isdefined(Core, :BFloat16)
+    @inline primitive_type(::Type{Core.BFloat16}) = 16
+end
 
 @inline primitive_type(::Type{Float64}) = 12
 
diff --git a/src/utils.jl b/src/utils.jl
index 4e137a3889..b379366bc5 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -2,11 +2,17 @@ function mlir_type(x::RArray{T,N}) where {T,N}
     return MLIR.IR.TensorType(size(x), MLIR.IR.Type(T))
 end
 
+mlir_type(::RNumber{T}) where {T} = MLIR.IR.TensorType((), MLIR.IR.Type(T))
+
 function mlir_type(::Type{<:RArray{T,N}}, shape) where {T,N}
     @assert length(shape) == N
     return MLIR.IR.TensorType(shape, MLIR.IR.Type(T))
 end
 
+function mlir_type(::Type{<:RNumber{T}}) where {T}
+    return MLIR.IR.TensorType((), MLIR.IR.Type(T))
+end
+
 function transpose_ty(mlirty)
     return MLIR.IR.TensorType([reverse(size(mlirty))...], eltype(mlirty))
 end
@@ -23,7 +29,10 @@ end
 
 function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=false)
     if sizeof(typeof(f)) != 0 || f isa BroadcastFunction
-        return (true, make_mlir_fn(apply, (f, args...), kwargs, name, concretein)[2:end]...)
+        return (
+            true,
+            make_mlir_fn(apply, (f, args...), kwargs, name, concretein; toscalar)[2:end]...,
+        )
     end
 
     N = length(args)
@@ -38,11 +47,9 @@ function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=fa
         )
     end
 
-    linear_args = TracedRArray[]
+    linear_args = TracedType[]
     for (k, v) in seen_args
-        if !(v isa TracedRArray)
-            continue
-        end
+        v isa TracedType || continue
         push!(linear_args, v)
     end
 
@@ -121,13 +128,10 @@ function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=fa
         )
     end
 
-    linear_results = TracedRArray[]
+    linear_results = TracedType[]
     for (k, v) in seen_results
-        if !(v isa TracedRArray)
-            continue
-        end
-
+        v isa TracedType || continue
         push!(linear_results, v)
     end
 
diff --git a/test/basic.jl b/test/basic.jl
index 368f22d728..56d4a121dc 100644
--- a/test/basic.jl
+++ b/test/basic.jl
@@ -65,6 +65,8 @@ end
 
 sumexp(x) = sum(exp, x)
 
+sum_compare(x) = sum(x) > 0
+
 @testset "Basic mapreduce" begin
     x = rand(Float32, 10)
     a = Reactant.ConcreteRArray(x)
@@ -74,6 +76,12 @@ sumexp(x) = sum(exp, x)
 
     f_res = f(a)
     @test f_res ≈ r_res
+
+    # Ensure we are tracing as scalars. Else this will fail due to > not being defined on
+    # arrays
+    f = @compile sum_compare(a)
+    # We need to use [] to unwrap the scalar. We will fix this in the future.
+    @test f(a)[] == sum_compare(x)
 end
 
 function mysoftmax!(x)
@@ -210,7 +218,90 @@ end
 end
 
 @testset "concatenation" begin
-    @testset "$(ndims(x))-dim" for x in [
+    @testset "Number" begin
+        x = fill(true)
+        x_concrete = Reactant.to_rarray(x)
+
+        # NOTE [,,,] is a call to `vect`, not `*cat`
+        # f = Reactant.compile((x_concrete,)) do x
+        #     return [x, x, x]
+        # end
+        # @test f(x_concrete) ≈ ones(3)
+
+        # vcat
+        test_vcat(x) = begin
+            x = x[] # unwrap scalar
+            [x; x; x]
+        end
+        f = @compile test_vcat(x_concrete)
+        @test f(x_concrete) == test_vcat(x)
+        @test eltype(f(x_concrete)) === Bool
+
+        # hcat
+        test_hcat(x) = begin
+            x = x[] # unwrap scalar
+            [x x x]
+        end
+        f = @compile test_hcat(x_concrete)
+        @test f(x_concrete) == test_hcat(x)
+        @test eltype(f(x_concrete)) === Bool
+
+        # hvcat
+        test_hvcat(x) = begin
+            x = x[] # unwrap scalar
+            [x x x; x x x]
+        end
+        f = @compile test_hvcat(x_concrete)
+        @test f(x_concrete) == test_hvcat(x)
+        @test eltype(f(x_concrete)) === Bool
+
+        # hvncat
+        test_hvncat(x) = begin
+            x = x[] # unwrap scalar
+            [x x x; x x x;;; x x x; x x x]
+        end
+        f = @compile test_hvncat(x_concrete)
+        @test f(x_concrete) == test_hvncat(x)
+        @test eltype(f(x_concrete)) === Bool
+
+        # typed_vcat
+        test_typed_vcat(x) = begin
+            x = x[] # unwrap scalar
+            Int[x; x; x]
+        end
+        f = @compile test_typed_vcat(x_concrete)
+        @test f(x_concrete) == test_typed_vcat(x)
+        @test eltype(f(x_concrete)) === Int
+
+        # typed_hcat
+        test_typed_hcat(x) = begin
+            x = x[] # unwrap scalar
+            Int[x x x]
+        end
+        f = @compile test_typed_hcat(x_concrete)
+        @test f(x_concrete) == test_typed_hcat(x)
+        @test eltype(f(x_concrete)) === Int
+
+        # typed_hvcat
+        test_typed_hvcat(x) = begin
+            x = x[] # unwrap scalar
+            Int[x x x; x x x]
+        end
+        f = @compile test_typed_hvcat(x_concrete)
+        @test f(x_concrete) == test_typed_hvcat(x)
+        @test eltype(f(x_concrete)) === Int
+
+        # typed_hvncat
+        test_typed_hvncat(x) = begin
+            x = x[] # unwrap scalar
+            Int[x x x; x x x;;; x x x; x x x]
+        end
+        f = @compile test_typed_hvncat(x_concrete)
+        @test f(x_concrete) == test_typed_hvncat(x)
+        @test eltype(f(x_concrete)) === Int
+    end
+
+    @testset "$(ndims(x))-dim Array" for x in [
         fill(true),
         [true, false],
         [true false],
diff --git a/test/bcast.jl b/test/bcast.jl
index 3294aeaab9..4e56015894 100644
--- a/test/bcast.jl
+++ b/test/bcast.jl
@@ -112,3 +112,23 @@ pow(x, n) = x .^ n
 
     @test pow_compiled(x_ra) ≈ pow(x, 2)
 end
+
+struct CustomBCastFunction{X}
+    x::X
+end
+
+(f::CustomBCastFunction)(x::Number) = f.x + x
+
+function custombcast(x)
+    fn = CustomBCastFunction(3.0)
+    return fn.(x)
+end
+
+@testset "Broadcasting closures / functors" begin
+    x = rand(2, 3)
+    x_ra = Reactant.to_rarray(x)
+
+    custombcast_compiled = @compile custombcast(x_ra)
+
+    @test custombcast_compiled(x_ra) ≈ custombcast(x)
+end
diff --git a/test/compile.jl b/test/compile.jl
index c5944e4c76..85692ad72b 100644
--- a/test/compile.jl
+++ b/test/compile.jl
@@ -7,7 +7,7 @@ Base.sum(x::NamedTuple{(:a,),Tuple{T}}) where {T<:Reactant.TracedRArray} = (; a=
 @testset "create_result" begin
     @testset "NamedTuple" begin
         x = (; a=rand(4, 3))
-        x2 = (; a=Reactant.ConcreteRArray(x.a))
+        x2 = Reactant.to_rarray(x)
 
         f = @compile sum(x2)
 
diff --git a/test/nn/lux.jl b/test/nn/lux.jl
index b4a4558783..49fa37f52c 100644
--- a/test/nn/lux.jl
+++ b/test/nn/lux.jl
@@ -1,15 +1,8 @@
 using Reactant, Lux, Random, Statistics, Enzyme, Functors, OneHotArrays
 
-function crossentropy(ŷ, y)
-    logŷ = log.(ŷ)
-    result = y .* logŷ
-    return -sum(result)
-end
-
 function loss_function(model, x, y, ps, st)
     y_hat, _ = model(x, ps, st)
-    # return CrossEntropyLoss()(y_hat, y) # <-- needs handling of xlogx xlogy from LuxOps
-    return crossentropy(y_hat, y)
+    return CrossEntropyLoss()(y_hat, y)
 end
 
 function gradient_loss_function(model, x, y, ps, st)
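
Usage sketch (not part of the patch): the snippet below is adapted from the `test/basic.jl` hunk above and assumes the patch is applied, with `Reactant` and `Test` loaded. It shows the behavior this change enables: a reduction now traces to a `TracedRNumber`, so a scalar comparison such as `sum(x) > 0` compiles instead of failing on a 0-d array, and the compiled result is a 0-d buffer that `[]` unwraps.

    using Reactant, Test

    # `sum` yields a traced scalar, so `>` dispatches to the new
    # TracedRNumber comparison methods rather than erroring on a 0-d array.
    sum_compare(x) = sum(x) > 0

    x = rand(Float32, 10)
    a = Reactant.ConcreteRArray(x)

    f = @compile sum_compare(a)

    # The compiled function returns a 0-d buffer; [] unwraps it to a Bool.
    @test f(a)[] == sum_compare(x)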