Add bias_act! #457

Merged · 11 commits · Sep 4, 2023

Changes from all commits
docs/src/reference.md (2 additions, 0 deletions)
@@ -75,6 +75,7 @@ pad_zeros
## Convolution

`Flux`'s `Conv` and `CrossCor` layers use `NNlib.DenseConvDims` and `NNlib.conv` internally.

`NNlib.conv` supports complex datatypes on CPU and CUDA devices.

!!! AMDGPU MIOpen supports only cross-correlation (flipkernel=true).
@@ -152,4 +153,5 @@ ctc_loss
logsumexp
NNlib.glu
NNlib.within_gradient
bias_act!
```
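As an aside on the complex-`conv` note above, here is a minimal sketch of what that support looks like in practice; the shapes and the `ComplexF64` eltype are illustrative assumptions, not part of this PR:

```julia
using NNlib

x = randn(ComplexF64, 8, 1, 1)   # (width, channels, batch)
w = randn(ComplexF64, 3, 1, 1)   # (width, in_channels, out_channels)
y = conv(x, w)                   # complex-valued convolution on the CPU
size(y)                          # (6, 1, 1)
```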
src/NNlib.jl (3 additions, 0 deletions)
@@ -72,6 +72,9 @@ export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter,
include("conv_bias_act.jl")
export conv_bias_act, conv_bias_act!

include("bias_act.jl")
export bias_act!

include("fold.jl")

include("ctc.jl")
src/bias_act.jl (107 additions, 0 deletions)
@@ -0,0 +1,107 @@

using NNlib: fast_act, tanh_fast
using ChainRulesCore

const RCR = RuleConfig{>:HasReverseMode}

# This just saves typing `only.(only.(` many times:
@inline only_derivative(y,f::F,x) where F = only(only(ChainRulesCore.derivatives_given_output(y, f, x)))

# This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)`
# is independent of `x`, as `_return_type` says `Union{}` when calling is an error.
struct NotaNumber <: Real end

"""
bias_act!(σ, x, b)

This is equivalent to `x .= σ.(x .+ b)`, also replacing `sigmoid` & `tanh`
with `sigmoid_fast` & `tanh_fast`.
It will only overwrite `x` when `x isa StridedArray{<:AbstractFloat}`.

When used within a gradient, it will overwrite only when `σ` has
a method of `derivatives_given_output` which does not need the input at all.
Such methods are defined by e.g. `@scalar_rule relu(x) Ω > 0` where the derivative
contains only `Ω` (the output) not `x`.

!!! warning
This is not safe to use if `x` is still needed for the gradient
of some other function. Incorrect use will give silently wrong answers.
It is intended mainly for Flux layers, in which the previous operation is
known to be safe, e.g. `bias_act!(σ, weight * input, bias)` for a `Dense` layer.
"""
bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::AbstractArray{<:Union{Bool, AbstractFloat}}) =
_fast_broadcast!(fast_act(σ, x)∘(+), x, b) # works around a SIMD bug

function bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::Bool)
b === true && error("bias=true is not accepted; layer constructors should guarantee this")
_fast_broadcast!(fast_act(σ, x), x)
end

function bias_act!(::typeof(identity), x::StridedArray{<:AbstractFloat}, b::Bool)
b === true && error("bias=true is not accepted; layer constructors should guarantee this")
x # pass-through
end

function bias_act!(σ::Function, x::AbstractArray, b)
b === true && error("bias=true is not accepted; layer constructors should guarantee this")
fast_act(σ, x).(x .+ b) # fallback
end

function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractArray{T,N}, b::B) where {F,T,N,B}
biasgrad = if eltype(B) !== Bool
# Summing over ndims(x)+1 is a trick to make b_dims type-stable
dims = ntuple(d -> size(b,d)==1 ? d : N+1, N)
_biasgrad(dx) = reshape(sum(dx; dims), size(b))
else
Returns(NoTangent())
end

# Fast path: it is now safe to overwrite x, since this is not needed for gradient of σ
if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber}))
Ω = bias_act!(σ, x, b) # now x === Ω, when x isa StridedArray{<:AbstractFloat}
function bias_act!_fastback(Δ)
# Tempting to overwrite x again, but only safe if you call pullback at most once,
# TODO with e.g. https://github.com/FluxML/Zygote.jl/pull/1340
# https://github.com/JuliaDiff/ChainRulesCore.jl/pull/592
dx = only_derivative.(Ω, σ, NotaNumber()) .* unthunk(Δ)
return (NoTangent(), NoTangent(), dx, biasgrad(dx))
end
return Ω, bias_act!_fastback

# # Slower path: can't overwrite x, but can use derivatives_given_output
# # This case is WRONG and tests fail, but not sure why
# elseif isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T}))
# Ω2 = fast_act(σ, x).(x) .+ b
# @show σ b
# function bias_act!_back2(Δ)
# dx = only_derivative.(Ω2, σ, x .+ b) .* unthunk(Δ)
# return (NoTangent(), NoTangent(), dx, biasgrad(dx))
# end
# return Ω2, bias_act!_back2

# Fallback path: let AD handle the broadcast
else
Ω3, back = rrule_via_ad(cfg, broadcast, fast_act(σ, x), bias_act!(identity, x, b))
@inline function bias_act!_slowback(Δ)
_, _, dx = back(Δ)
return (NoTangent(), NoTangent(), dx, biasgrad(dx))
end
return Ω3, bias_act!_slowback
end
end

# Two easy cases with identity
function rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArray{T,N}, b::B) where {T,N,B}
dims = ntuple(d -> size(b,d)==1 ? d : N+1, N)
biasgrad(dx) = reshape(sum(dx; dims), size(b))
function bias_act!_idback(Δ)
dx = unthunk(Δ)
return (NoTangent(), NoTangent(), dx, biasgrad(dx))
end
return bias_act!(identity, x, b), bias_act!_idback
end
function rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArray{T,N}, b::Bool) where {T,N}
bias_act!_trivial(Δ) = (NoTangent(), NoTangent(), Δ, NoTangent())
return x, bias_act!_trivial
end
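
For orientation, a minimal usage sketch of the new function (not part of the diff; the `Dense`-style sizes and the Zygote call are illustrative assumptions):

```julia
using NNlib, Zygote

W = randn(Float32, 5, 3)
b = randn(Float32, 5)
x = randn(Float32, 3, 8)

y = W * x                     # freshly allocated, so it is safe to overwrite
out = bias_act!(relu, y, b)   # fuses the bias add and activation, reusing y's memory
out ≈ relu.(W * x .+ b)       # true; here out === y, since y is a strided float array

# Inside a gradient, relu's derivative needs only the output, so the rrule takes the
# fast overwriting path; W * x is not needed elsewhere, so mutating it is safe.
Zygote.gradient(w -> sum(bias_act!(relu, w * x, b)), W)
```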

src/dropout.jl (0 additions, 21 deletions)
@@ -125,27 +125,6 @@ end
# and seems about as fast, but needs a method `copy(::CUDA.RNG)` and careful checking.
# https://github.com/FluxML/NNlib.jl/pull/454#issuecomment-1369357402

"""
_fast_broadcast!(f, x, y, z...)

This does `x .= f.(x, y, z...)`, but works around
an issue with broadcasting that prevents SIMD in such cases.
Can be removed once https://github.com/JuliaLang/julia/issues/43153 is fixed.

Not intended for general use. Does not check sizes!
"""
function _fast_broadcast!(f::F, x::Array, yz...) where {F<:Function}
bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, yz...))
@simd ivdep for I in eachindex(bc)
@inbounds x[I] = bc[I]
end
return x
end
function _fast_broadcast!(f::F, x::AbstractArray, yz...) where {F<:Function}
# CUDA does not suffer from this bug
broadcast!(f, x, x, yz...)
end


"""
_rng_from_array(x)
src/utils.jl (60 additions, 4 deletions)
@@ -53,15 +53,21 @@ ChainRulesCore.rrule(::typeof(within_gradient), x) = true, _ -> (NoTangent(), No
"""
safe_div(x, y)

Safely divide `x` by `y`. If `y` is zero, return `x` directly.
Returns `x/y` unless `y==0`, in which case it just returns `x`.
(Used internally by `scatter`.)
"""
safe_div(x, y) = ifelse(iszero(y), x, x/y)

"""
maximum_dims(dims)

Return the maximum value for each dimension. An array of dimensions `dims` is accepted.
The maximum of each dimension in the element is computed.
Given an array of `CartesianIndex{N}` or `NTuple{N,Int}`,
returns a tuple containing the maximum of all the 1st entries,
all the 2nd entries, and so on up to `N`.

Given an array of integers, returns `(maximum(dims),)`.

(These arguments are what [`scatter`](@ref NNlib.scatter) understands.)
"""
maximum_dims(dims::AbstractArray{<:Integer}) = (maximum(dims), )
maximum_dims(dims::AbstractArray{NTuple{N, T}}) where {N,T} = ntuple(i -> maximum(x->x[i], dims), N)
@@ -105,4 +111,54 @@ function reverse_indices(idx::AbstractArray{<:Any,N}) where N
return reverse_indices!(rev, idx)
end

unsqueeze(x) = reshape(x, 1, size(x)...)
unsqueeze(x) = reshape(x, 1, size(x)...)


"""
_fast_broadcast!(f, x, y, z...)

This does `x .= f.(x, y, z...)`, but works around
an issue with broadcasting that prevents SIMD in such cases.
Can perhaps be removed once https://github.com/JuliaLang/julia/issues/43153 is fixed.

Has an `rrule` to avoid mutation within derivatives.

!!! warning
Not intended for general use.
Uses `@inbounds` but does not check sizes!
Assumes that `f` has no derivative!
"""
function _fast_broadcast!(f::F, x::Array, yz...) where {F<:Function}
bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, yz...))
@simd ivdep for I in eachindex(bc)
@inbounds x[I] = bc[I]
end
return x
end
function _fast_broadcast!(f::F, x::AbstractArray, yz...) where {F<:Function}
# CUDA does not suffer from this bug
broadcast!(f, x, x, yz...)
end

function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_broadcast!), f::F, x::AbstractArray, ys...) where {F<:Function}
rrule_via_ad(cfg, broadcast, f, x, ys...)
end

# Could get this from Compat.jl instead
# https://github.com/JuliaLang/julia/pull/39794
if VERSION < v"1.7.0-DEV.793"
struct Returns{V} <: Function
value::V
Returns{V}(value) where {V} = new{V}(value)
Returns(value) = new{Core.Typeof(value)}(value)
end

(obj::Returns)(args...; kw...) = obj.value
function Base.show(io::IO, obj::Returns)
show(io, typeof(obj))
print(io, "(")
show(io, obj.value)
print(io, ")")
end
end
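
A quick illustration of the two helpers documented above, as a sketch using only the methods visible in this diff:

```julia
using NNlib

# _fast_broadcast!: in-place x .= f.(x, y), written as a plain loop so SIMD applies
x, y = rand(Float32, 4), rand(Float32, 4)
xref = copy(x)
NNlib._fast_broadcast!(+, x, y)
x ≈ xref .+ y                          # true; x was overwritten

# maximum_dims: per-dimension maxima of an array of tuples, or a 1-tuple for integers
NNlib.maximum_dims([(1, 2), (3, 1)])   # (3, 2)
NNlib.maximum_dims([4, 2, 7])          # (7,)
```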

test/bias_act.jl (114 additions, 0 deletions)
@@ -0,0 +1,114 @@
using NNlib, Zygote, ChainRulesCore, Test
using Zygote: ForwardDiff

ACTIVATION_FUNCTIONS =
[@eval($a) for a in NNlib.ACTIVATIONS]

@testset "bias_act!" begin
x = randn(3,4)
b = randn(3)
@test @inferred(bias_act!(identity, x, false)) === x # pass-through
@test @inferred(bias_act!(identity, copy(x), b)) ≈ (x .+ b)
@test @inferred(bias_act!(relu, copy(x), b)) ≈ relu.(x .+ b)
@test @inferred(bias_act!(tanh, copy(x), b)) ≈ tanh.(x .+ b)
@test @inferred(bias_act!(tanh, copy(x), false)) ≈ tanh.(x)

# Check that it does overwrite:
x32 = rand(Float32, 3, 4); x32copy = copy(x32)
@test @inferred(bias_act!(cbrt, x32, b)) ≈ cbrt.(x32copy .+ b)
@test x32 ≈ cbrt.(x32copy .+ b)

x32 = rand(Float32, 3, 4); x32copy = copy(x32) # without bias
@test @inferred(bias_act!(tanh, x32, false)) ≈ tanh.(x32copy)
@test x32 ≈ tanh.(x32copy)

x32 = rand(Float32, 3, 4); x32copy = copy(x32) # now check gradient rule
y, back = rrule(Zygote.ZygoteRuleConfig(), bias_act!, relu, x32, b)
@test y ≈ x32 ≈ relu.(x32copy .+ b)

x32 = rand(Float32, 3, 4); x32copy = copy(x32) # without bias
y, back = rrule(Zygote.ZygoteRuleConfig(), bias_act!, relu, x32, false)
@test y ≈ x32 ≈ relu.(x32copy)

# Check that it doesn't try to overwrite non-float arrays:
xint = rand(-3:3, 3, 4)
bint = rand(-2:2, 3)
@test bias_act!(identity, copy(xint), bint) ≈ xint .+ bint
@test bias_act!(tanh, copy(xint), bint) ≈ tanh.(xint .+ bint)
@test bias_act!(tanh, copy(xint), false) ≈ tanh.(xint)

# Reject bias===true so that Bool means one thing:
@test_throws Exception bias_act!(identity, rand(3), true)
@test_throws Exception bias_act!(cbrt, rand(3), true)
@test_throws Exception bias_act!(cbrt, rand(1:3, 3), true)

@testset "gradient with $fun" for fun in vcat([identity, tanh, cbrt],
ACTIVATION_FUNCTIONS,
[x->x, x -> 1/(x^2+2), x -> leakyrelu(x, 0.33)])
# Only some of these go the fast path, `cbrt` is an example of a function NNlib knows nothing about.
fun == rrelu && continue # this one is randomised!
fun == hardσ && continue # this one has heisenbugs, not solved by discontinuity-avoidance code below

@test bias_act!(fun, copy(x), b) ≈ fun.(x .+ b)
@test bias_act!(fun, copy(x), false) ≈ fun.(x)

gx = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x)
gxplus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x .+ eps())
gxminus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x .- eps())
if !(gx ≈ gxplus ≈ gxminus)
@warn "skipping gradient tests due to discontinuity" fun x b
continue
end
@test gx ≈ Zygote.gradient(x -> sum(bias_act!(fun, copy(x), b)), x)[1]

gx2 = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x)
gx2plus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x .+ eps())
gx2minus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x .- eps())
if !(gx2 ≈ gx2plus ≈ gx2minus)
@warn "skipping gradient tests due to discontinuity" fun x
continue
[Review comment by @mcabbott (Member, Author), Jan 7, 2023, on lines +64 to +69:
This slightly elaborate thing is avoiding my best guess as to why there were failures on CI: hardsigmoid has discontinuities, and if x hits them, the two gradients may not agree.
But it doesn't seem to work:
  gradient with hardσ: Test Failed at /home/runner/work/NNlib.jl/NNlib.jl/test/bias_act.jl:73
  Expression: gb ≈ (Zygote.gradient((b->(sum(bias_act!(fun, copy(x), b));)), b))[1]
   Evaluated: [0.5, 0.6666666666666666, 0.6666666666666666] ≈ [1.5000000000000002, 0.6666666666666666, 0.6666666666666666]
]

end
@test gx2 ≈ Zygote.gradient(x -> sum(bias_act!(fun, copy(x), false)), x)[1]

gb = ForwardDiff.gradient(b -> sum(bias_act!(fun, copy(x), b)), b)
@test gb ≈ Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), b)[1]

@test Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), false) == (nothing,)
@test Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), b .> 0) == (nothing,)
end

@testset "gradient for fast_broadcast!" begin
# Gradient definition is just to disable mutation inside 2nd order AD
gx = ForwardDiff.gradient(x -> sum(NNlib._fast_broadcast!(cbrt∘(+), copy(x), b)), x)
@test gx ≈ Zygote.gradient(x -> sum(NNlib._fast_broadcast!(cbrt∘(+), copy(x), b)), x)[1]

# relu should take the fast path
g2 = ForwardDiff.gradient(x) do x
sum(abs2, Zygote.gradient(x -> sum(abs2, bias_act!(relu, copy(x), b)), x)[1])
end
@test_skip g2 ≈ Zygote.gradient(x) do x # Here global variable b causes an error
sum(abs2, Zygote.gradient(x -> sum(abs2, bias_act!(relu, copy(x), b)), x)[1])
end
# Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_eqtable_get), Any, svec(Any, Any, Any), 0, :(:ccall), %5, %3, %4)).
# [5] (::typeof(∂(accum_global)))(Δ::Nothing)
@test g2 ≈ Zygote.gradient(x, b) do x, b
sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(relu, copy(x), b)), x, b)[1])
end[1]

g3 = ForwardDiff.gradient(x) do x
sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(swish, copy(x), b)), x, b)[1])
end
@test g3 ≈ Zygote.gradient(x, b) do x, b
sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(swish, copy(x), b)), x, b)[1])
end[1]

# Anon function sure to take the generic path
g4 = ForwardDiff.gradient(x) do x
sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(y -> cbrt(y/3), copy(x), b)), x, b)[1])
end
@test g4 ≈ Zygote.gradient(x, b) do x, b
sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(y -> cbrt(y/3), copy(x), b)), x, b)[1])
end[1]
end
end
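
For context on the fast-path check in the `rrule` above, here is a small probe mirroring that logic. It is a sketch: the helper names `only_deriv` and `can_overwrite` are made up here, and it assumes (as the docstring says) that `relu`'s scalar rule is written in terms of the output `Ω`, while the anonymous function mirrors the tests' "sure to take the generic path" case.

```julia
using NNlib, ChainRulesCore

# The same question the rrule asks: can σ's derivative be computed from the output alone?
only_deriv(y, f, x) = only(only(ChainRulesCore.derivatives_given_output(y, f, x)))
can_overwrite(f) = isconcretetype(
    Core.Compiler._return_type(only_deriv, Tuple{Float64, typeof(f), NNlib.NotaNumber}))

can_overwrite(relu)        # true: the derivative is Ω > 0, so bias_act! may overwrite x
can_overwrite(x -> x / 3)  # false: no derivatives_given_output rule, generic path instead
```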

test/runtests.jl (1 addition, 0 deletions)
@@ -127,6 +127,7 @@ end

@testset "Activation Functions" begin
include("activations.jl")
include("bias_act.jl")
end

@testset "Attention" begin