Skip to content

Commit abc6a9e

Browse files
committed
feat: support logsoftmax
1 parent 6aab7f7 commit abc6a9e

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

ext/ReactantNNlibExt.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,20 @@ function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where
3232
return out ./= tmp
3333
end
3434

35+
function NNlib.logsoftmax!(out::TracedRArray{T}, x::AbstractArray; dims=1) where {T}
36+
max_ = NNlib.fast_maximum(x; dims)
37+
# if all(isfinite, max_)
38+
@fastmath out .= x .- max_
39+
# else
40+
# _zero, _minf, _inf = T(0), T(-Inf), T(Inf)
41+
# @. out = ifelse(
42+
# isequal(max_, _inf), ifelse(isequal(x, _inf), _zero, _minf), x - max_
43+
# )
44+
# end
45+
@fastmath log_ = log.(sum(exp, out; dims))
46+
return out .-= log_
47+
end
48+
3549
function NNlib.conv(
3650
x::AnyTracedRArray{T,N}, W::AnyTracedRArray{T}, cdims::DenseConvDims
3751
) where {T,N}

src/TracedRNumber.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,8 @@ end
193193
# XXX: Enzyme-MLIR doesn't have `abs` adjoint defined
194194
Base.abs2(x::TracedRNumber{<:Real}) = x^2
195195

196+
Base.log1p(x::TracedRNumber{T}) where {T} = log(x + one(T))
197+
196198
struct TypeCast{T<:ReactantPrimitives} <: Function end
197199

198200
(::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} = promote_to(TracedRNumber{T}, x)

0 commit comments

Comments
 (0)