-
-
Notifications
You must be signed in to change notification settings - Fork 127
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add bias_act!
#457
Merged
Merged
Add bias_act!
#457
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
554f339
sometimes-in-place bias_act
mcabbott 13ab2e7
update after dropout PR
mcabbott f08fc0b
add to docs
mcabbott a9136e7
also fix two unrelated docstring which just told you what the functio…
mcabbott 83882de
tidy & un-comment
mcabbott 419725c
comment out 2nd path again
mcabbott dbf39d4
add Returns for 1.6
mcabbott 791531a
upgrade tests
mcabbott 7b04b15
more tests
mcabbott c9a5722
skip hardσ tests
mcabbott cd99b77
Update test/bias_act.jl
mcabbott File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
|
||
using NNlib: fast_act, tanh_fast | ||
using ChainRulesCore | ||
|
||
const RCR = RuleConfig{>:HasReverseMode} | ||
|
||
# This just saves typing `only.(only.(` many times: | ||
@inline only_derivative(y,f::F,x) where F = only(only(ChainRulesCore.derivatives_given_output(y, f, x))) | ||
|
||
# This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)` | ||
# is independent of `x`, as `_return_type` says `Union{}` when calling is an error. | ||
struct NotaNumber <: Real end | ||
|
||
""" | ||
bias_act!(σ, x, b) | ||
|
||
This is equivalent to `x .= σ.(x .+ b)`, also replacing `sigmoid` & `tanh` | ||
with `sigmoid_fast` & `tanh_fast`. | ||
It will only overwrite `x` when `x isa StridedArray{<:AbstractFloat}`. | ||
|
||
When used within a gradient, it will overwrite only when `σ` has | ||
a method of `derivatives_given_output` which does not need the input at all. | ||
Such methods are defined by e.g. `@scalar_rule relu(x) Ω > 0` where the derivative | ||
contains only `Ω` (the output) not `x`. | ||
|
||
!!! warning | ||
This is not safe to use if `x` is still needed for the gradient | ||
of some other function. Incorrect use will give silently wrong answers. | ||
It is intended mainly for Flux layers, in which the previous operation is | ||
known to be safe, e.g. `bias_act!(σ, weight * input, bias)` for a `Dense` layer. | ||
""" | ||
bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::AbstractArray{<:Union{Bool, AbstractFloat}}) = | ||
_fast_broadcast!(fast_act(σ, x)∘(+), x, b) # works around a SIMD bug | ||
|
||
function bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::Bool) | ||
b === true && error("bias=true is not accepted; layer constructors shoud guarantee this") | ||
_fast_broadcast!(fast_act(σ, x), x) | ||
end | ||
|
||
function bias_act!(::typeof(identity), x::StridedArray{<:AbstractFloat}, b::Bool) | ||
b === true && error("bias=true is not accepted; layer constructors shoud guarantee this") | ||
x # pass-through | ||
end | ||
|
||
function bias_act!(σ::Function, x::AbstractArray, b) | ||
b === true && error("bias=true is not accepted; layer constructors shoud guarantee this") | ||
fast_act(σ, x).(x .+ b) # fallback | ||
end | ||
|
||
function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractArray{T,N}, b::B) where {F,T,N,B} | ||
biasgrad = if eltype(B) !== Bool | ||
# Summing over ndims(x)+1 is a trick to make b_dims type-stable | ||
dims = ntuple(d -> size(b,d)==1 ? d : N+1, N) | ||
_biasgrad(dx) = reshape(sum(dx; dims), size(b)) | ||
else | ||
Returns(NoTangent()) | ||
end | ||
|
||
# Fast path: it is now safe to overwrite x, since this is not needed for gradient of σ | ||
if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) | ||
Ω = bias_act!(σ, x, b) # now x === Ω, when x isa StridedArray{<:AbstractFloat} | ||
function bias_act!_fastback(Δ) | ||
# Tempting to overwrite x again, but only safe if you call pullback at most once, | ||
# TODO with e.g. https://github.com/FluxML/Zygote.jl/pull/1340 | ||
# https://github.com/JuliaDiff/ChainRulesCore.jl/pull/592 | ||
dx = only_derivative.(Ω, σ, NotaNumber()) .* unthunk(Δ) | ||
return (NoTangent(), NoTangent(), dx, biasgrad(dx)) | ||
end | ||
return Ω, bias_act!_fastback | ||
|
||
# # Slower path: can't overwrite x, but can use derivatives_given_output | ||
# # This case is WRONG and tests fail, but not sure why | ||
# elseif isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) | ||
# Ω2 = fast_act(σ, x).(x) .+ b | ||
# @show σ b | ||
# function bias_act!_back2(Δ) | ||
# dx = only_derivative.(Ω2, σ, x .+ b) .* unthunk(Δ) | ||
# return (NoTangent(), NoTangent(), dx, biasgrad(dx)) | ||
# end | ||
# return Ω2, bias_act!_back2 | ||
|
||
# Fallback path: let AD handle the broadcast | ||
else | ||
Ω3, back = rrule_via_ad(cfg, broadcast, fast_act(σ, x), bias_act!(identity, x, b)) | ||
@inline function bias_act!_slowback(Δ) | ||
_, _, dx = back(Δ) | ||
return (NoTangent(), NoTangent(), dx, biasgrad(dx)) | ||
end | ||
return Ω3, bias_act!_slowback | ||
end | ||
end | ||
|
||
# Two easy cases with identity | ||
function rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArray{T,N}, b::B) where {T,N,B} | ||
dims = ntuple(d -> size(b,d)==1 ? d : N+1, N) | ||
biasgrad(dx) = reshape(sum(dx; dims), size(b)) | ||
function bias_act!_idback(Δ) | ||
dx = unthunk(Δ) | ||
return (NoTangent(), NoTangent(), dx, biasgrad(dx)) | ||
end | ||
return bias_act!(identity, x, b), bias_act!_idback | ||
end | ||
function rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArray{T,N}, b::Bool) where {T,N} | ||
bias_act!_trivial(Δ) = (NoTangent(), NoTangent(), Δ, NoTangent()) | ||
return x, bias_act!_trivial | ||
end | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
using NNlib, Zygote, ChainRulesCore, Test | ||
using Zygote: ForwardDiff | ||
|
||
ACTIVATION_FUNCTIONS = | ||
[@eval($a) for a in NNlib.ACTIVATIONS] | ||
|
||
@testset "bias_act!" begin | ||
x = randn(3,4) | ||
b = randn(3) | ||
@test @inferred(bias_act!(identity, x, false)) === x # pass-through | ||
@test @inferred(bias_act!(identity, copy(x), b)) ≈ (x .+ b) | ||
@test @inferred(bias_act!(relu, copy(x), b)) ≈ relu.(x .+ b) | ||
@test @inferred(bias_act!(tanh, copy(x), b)) ≈ tanh.(x .+ b) | ||
@test @inferred(bias_act!(tanh, copy(x), false)) ≈ tanh.(x) | ||
|
||
# Check that it does overwrite: | ||
x32 = rand(Float32, 3, 4); x32copy = copy(x32) | ||
@test @inferred(bias_act!(cbrt, x32, b)) ≈ cbrt.(x32copy .+ b) | ||
@test x32 ≈ cbrt.(x32copy .+ b) | ||
|
||
x32 = rand(Float32, 3, 4); x32copy = copy(x32) # without bias | ||
@test @inferred(bias_act!(tanh, x32, false)) ≈ tanh.(x32copy) | ||
@test x32 ≈ tanh.(x32copy) | ||
|
||
x32 = rand(Float32, 3, 4); x32copy = copy(x32) # now check gradient rule | ||
y, back = rrule(Zygote.ZygoteRuleConfig(), bias_act!, relu, x32, b) | ||
@test y ≈ x32 ≈ relu.(x32copy .+ b) | ||
|
||
x32 = rand(Float32, 3, 4); x32copy = copy(x32) # without bias | ||
y, back = rrule(Zygote.ZygoteRuleConfig(), bias_act!, relu, x32, false) | ||
@test y ≈ x32 ≈ relu.(x32copy) | ||
|
||
# Check that it doesn't try to overwrite non-float arrays: | ||
xint = rand(-3:3, 3, 4) | ||
bint = rand(-2:2, 3) | ||
@test bias_act!(identity, copy(xint), bint) ≈ xint .+ bint | ||
@test bias_act!(tanh, copy(xint), bint) ≈ tanh.(xint .+ bint) | ||
@test bias_act!(tanh, copy(xint), false) ≈ tanh.(xint) | ||
|
||
# Reject bias===true so that Bool means one thing: | ||
@test_throws Exception bias_act!(identity, rand(3), true) | ||
@test_throws Exception bias_act!(cbrt, rand(3), true) | ||
@test_throws Exception bias_act!(cbrt, rand(1:3, 3), true) | ||
|
||
@testset "gradient with $fun" for fun in vcat([identity, tanh, cbrt], | ||
ACTIVATION_FUNCTIONS, | ||
[x->x, x -> 1/(x^2+2), x -> leakyrelu(x, 0.33)]) | ||
# Only some of these go the fast path, `cbrt` is an example of a function NNlib knows nothing about. | ||
fun == rrelu && continue # this one is randomised! | ||
fun == hardσ && continue # this one has heisenbugs, not solved by discontinuity-avoidance code below | ||
|
||
@test bias_act!(fun, copy(x), b) ≈ fun.(x .+ b) | ||
@test bias_act!(fun, copy(x), false) ≈ fun.(x) | ||
|
||
gx = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x) | ||
gxplus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x .+ eps()) | ||
gxminus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x .- eps()) | ||
if !(gx ≈ gxplus ≈ gxminus) | ||
@warn "skipping gradient tests due to discontinuity" fun x b | ||
continue | ||
end | ||
@test gx ≈ Zygote.gradient(x -> sum(bias_act!(fun, copy(x), b)), x)[1] | ||
|
||
gx2 = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x) | ||
gx2plus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x .- eps()) | ||
gx2minus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x .- eps()) | ||
if !(gx2 ≈ gx2plus ≈ gx2minus) | ||
@warn "skipping gradient tests due to discontinuity" fun x | ||
continue | ||
end | ||
@test gx2 ≈ Zygote.gradient(x -> sum(bias_act!(fun, copy(x), false)), x)[1] | ||
|
||
gb = ForwardDiff.gradient(b -> sum(bias_act!(fun, copy(x), b)), b) | ||
@test gb ≈ Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), b)[1] | ||
|
||
@test Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), false) == (nothing,) | ||
@test Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), b .> 0) == (nothing,) | ||
end | ||
|
||
@testset "gradient for fast_broadcast!" begin | ||
# Gradient definition is just to disable mutation inside 2nd order AD | ||
gx = ForwardDiff.gradient(x -> sum(NNlib._fast_broadcast!(cbrt∘(+), copy(x), b)), x) | ||
@test gx ≈ Zygote.gradient(x -> sum(NNlib._fast_broadcast!(cbrt∘(+), copy(x), b)), x)[1] | ||
|
||
# relu should take the fast path | ||
g2 = ForwardDiff.gradient(x) do x | ||
sum(abs2, Zygote.gradient(x -> sum(abs2, bias_act!(relu, copy(x), b)), x)[1]) | ||
end | ||
@test_skip gx ≈ Zygote.gradient(x) do x # Here global variable b causes an error | ||
sum(abs2,Zygote. gradient(x -> sum(abs2, bias_act!(relu, copy(x), b)), x)[1]) | ||
end | ||
# Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_eqtable_get), Any, svec(Any, Any, Any), 0, :(:ccall), %5, %3, %4)). | ||
# [5] (::typeof(∂(accum_global)))(Δ::Nothing) | ||
@test g2 ≈ Zygote.gradient(x, b) do x, b | ||
sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(relu, copy(x), b)), x, b)[1]) | ||
end[1] | ||
|
||
g3 = ForwardDiff.gradient(x) do x | ||
sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(swish, copy(x), b)), x, b)[1]) | ||
end | ||
@test g3 ≈ Zygote.gradient(x, b) do x, b | ||
sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(swish, copy(x), b)), x, b)[1]) | ||
end[1] | ||
|
||
# Anon function sure to take the generic path | ||
g4 = ForwardDiff.gradient(x) do x | ||
sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(y -> cbrt(y/3), copy(x), b)), x, b)[1]) | ||
end | ||
@test g4 ≈ Zygote.gradient(x, b) do x, b | ||
sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(y -> cbrt(y/3), copy(x), b)), x, b)[1]) | ||
end[1] | ||
end | ||
end | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This slightly elaborate thing is avoiding my best guess as to why there were failures on CI: hardsigmoid has discontinuities, and if
x
hits them, the two gradients may not agree.But it doesn't seem to work: