diff --git a/Project.toml b/Project.toml
index f0422bc486..b4c3dc6d75 100644
--- a/Project.toml
+++ b/Project.toml
@@ -30,7 +30,7 @@ ChainRulesCore = "1.12"
 Functors = "0.3, 0.4"
 MLUtils = "0.2, 0.3.1, 0.4"
 MacroTools = "0.5"
-NNlib = "0.8.9"
+NNlib = "0.8.14"
 NNlibCUDA = "0.2.4"
 OneHotArrays = "0.1, 0.2"
 Optimisers = "0.2.12"
diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 9524c0c284..1ffe742ed4 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -167,12 +167,12 @@ end
 
 @functor Dense
 
-function (a::Dense)(x::AbstractVecOrMat)
-  σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
-  return σ.(a.weight * x .+ a.bias)
-end
+(a::Dense)(x::AbstractVecOrMat) = bias_act!(a.σ, a.weight * x, a.bias)
 
-(a::Dense)(x::AbstractArray) =
+(a::Dense{typeof(identity), <:AbstractMatrix, <:AbstractVector})(x::AbstractVecOrMat) =
+  muladd(a.weight, x, a.bias)  # fast path, fuse addition
+
+(a::Dense)(x::AbstractArray) =
   reshape(a(reshape(x, size(x,1), :)), :, size(x)[2:end]...)
 
 function Base.show(io::IO, l::Dense)
@@ -194,7 +194,7 @@ Create an element-wise layer, whose forward pass is given by:
     y = σ.(scale .* x .+ bias)
 
 This uses `.*` instead of matrix multiplication `*` of [`Dense`](@ref).
-    
+
 The learnable scale & bias are initialised `init(size...)` and `zeros32(size...)`,
 with `init=ones32` by default. You may specify the function `init`,
 turn off trainable bias with `bias=false`, or provide the array(s) explicitly.
@@ -434,7 +434,7 @@ function (a::Bilinear)(x::AbstractMatrix, y::AbstractMatrix)
     Z = reshape(Wyx, (d_z, :))
 
     # @einsum out[o,s] := σ(Z[o,i] + b[o])
-    σ.(Z .+ b)
+    bias_act!(σ, Z, b)
 end
 
 (a::Bilinear)(x::AbstractVecOrMat) = a(x, x)
diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index 003395c15d..70585f2a56 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -163,7 +163,7 @@ end
 function Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
               init = glorot_uniform, stride = 1, pad = 0, dilation = 1, groups = 1,
               bias = true) where N
-  
+
   weight = convfilter(k, ch; init, groups)
   Conv(weight, bias, σ; stride, pad, dilation, groups)
 end
@@ -195,9 +195,9 @@ conv_dims(c::Conv, x::AbstractArray) =
 ChainRulesCore.@non_differentiable conv_dims(::Any, ::Any)
 
 function (c::Conv)(x::AbstractArray)
-  σ = NNlib.fast_act(c.σ, x)
   cdims = conv_dims(c, x)
-  σ.(conv(x, c.weight, cdims) .+ conv_reshape_bias(c))
+  y = conv(x, c.weight, cdims)
+  bias_act!(c.σ, y, conv_reshape_bias(c))
 end
 
 _channels_in(l::Conv) = size(l.weight, ndims(l.weight)-1) * l.groups
@@ -328,9 +328,9 @@ end
 ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any)
 
 function (c::ConvTranspose)(x::AbstractArray)
-  σ = NNlib.fast_act(c.σ, x)
   cdims = conv_transpose_dims(c, x)
-  σ.(∇conv_data(x, c.weight, cdims) .+ conv_reshape_bias(c))
+  y = ∇conv_data(x, c.weight, cdims)
+  bias_act!(c.σ, y, conv_reshape_bias(c))
 end
 
 function Base.show(io::IO, l::ConvTranspose)
@@ -466,9 +466,9 @@ crosscor_dims(c::CrossCor, x::AbstractArray) =
 ChainRulesCore.@non_differentiable crosscor_dims(::Any, ::Any)
 
 function (c::CrossCor)(x::AbstractArray)
-  σ = NNlib.fast_act(c.σ, x)
   cdims = crosscor_dims(c, x)
-  σ.(crosscor(x, c.weight, cdims) .+ conv_reshape_bias(c))
+  y = crosscor(x, c.weight, cdims)
+  bias_act!(c.σ, y, conv_reshape_bias(c))
 end
 
 function Base.show(io::IO, l::CrossCor)
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 7cabc9d5b6..449ab573b2 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -202,8 +202,7 @@ RNNCell((in, out)::Pair, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_
 
 function (m::RNNCell{F,I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {F,I,H,V,T}
   Wi, Wh, b = m.Wi, m.Wh, m.b
-  σ = NNlib.fast_act(m.σ, x)
-  h = σ.(Wi*x .+ Wh*h .+ b)
+  h = bias_act!(m.σ, Wi*x, muladd(Wh, h, b))
   return h, reshape_cell_output(h, x)
 end
 
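For reference, a minimal sketch (not part of the patch) of the pattern this diff switches to, assuming only `NNlib.bias_act!` as pinned by the `NNlib = "0.8.14"` compat bound above; the array sizes and the `tanh` activation are arbitrary illustration:

```julia
using NNlib: bias_act!

W = randn(Float32, 4, 3)
b = randn(Float32, 4)
x = randn(Float32, 3, 8)

# Old pattern: materialises W*x, then allocates a second array for the broadcast.
y_old = tanh.(W * x .+ b)

# New pattern: bias_act!(σ, y, b) overwrites y with σ.(y .+ b), so the W*x buffer
# is reused for the activated output (and, like the removed NNlib.fast_act calls,
# it may substitute tanh_fast for tanh internally).
y_new = bias_act!(tanh, W * x, b)

# The Dense fast path in this diff: for σ == identity, muladd fuses the
# matrix multiplication with the bias addition.
y_id = muladd(W, x, b)

y_old ≈ y_new  # true, up to tanh vs tanh_fast rounding
```

The intent is fewer temporaries per forward pass; apart from the fast `tanh`/`sigmoid` variants, the numerics should be unchanged.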