diff --git a/src/Optimisers.jl b/src/Optimisers.jl
index cb3afa97..9c3b0be0 100644
--- a/src/Optimisers.jl
+++ b/src/Optimisers.jl
@@ -2,6 +2,8 @@ module Optimisers
 
 using Functors: functor, fmap, isleaf, @functor, fmapstructure, children, AbstractWalk
 using LinearAlgebra
+using ChainRulesCore: canonicalize, backing, Tangent, AbstractZero, ZeroTangent
+using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, unthunk
 
 include("interface.jl")
 export AbstractRule
@@ -16,6 +18,8 @@ export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
        AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
        WeightDecay, ClipGrad, ClipNorm, OptimiserChain, Lion
 
+include("deprecations.jl")
+
 ###
 ### one-array functions
 ###
diff --git a/src/deprecations.jl b/src/deprecations.jl
new file mode 100644
index 00000000..94df7bef
--- /dev/null
+++ b/src/deprecations.jl
@@ -0,0 +1,12 @@
+# To be removed in Optimisers v0.3
+
+@deprecate iswriteable maywrite false  # remove when releasing Optimisers@0.3
+
+@deprecate ADAM Adam
+@deprecate NADAM NAdam
+@deprecate ADAMW AdamW
+@deprecate RADAM RAdam
+@deprecate OADAM OAdam
+@deprecate ADAGrad AdaGrad
+@deprecate ADADelta AdaDelta
+
diff --git a/src/destructure.jl b/src/destructure.jl
index 3b21d918..a0c6d2f6 100644
--- a/src/destructure.jl
+++ b/src/destructure.jl
@@ -1,5 +1,4 @@
-using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, unthunk
 
 const NoT = NoTangent()
 
 """
diff --git a/src/interface.jl b/src/interface.jl
index 04ead1f3..936d8c3d 100644
--- a/src/interface.jl
+++ b/src/interface.jl
@@ -1,7 +1,7 @@
-using ChainRulesCore: canonicalize, backing, Tangent, AbstractZero, ZeroTangent
 
 base(dx::Tangent) = backing(canonicalize(dx))
 base(dx) = dx
+
 const Zero = Union{Nothing, AbstractZero}  # Union{Zygote, Diffractor}
 
 abstract type AbstractRule end
@@ -15,6 +15,7 @@ mutable struct Leaf{R,S}  # mutable so that its identity encodes parameter shari
   state::S
   frozen::Bool  # ... and to allow freeze! to act on this.
 end
+
 Leaf(rule, state; frozen::Bool = false) = Leaf(rule, state, frozen)
 
 @functor Leaf
@@ -23,9 +24,9 @@ Base.:(==)(a::Leaf, b::Leaf) = children(a) == children(b)
 
 function setup(rule::AbstractRule, model)
   cache = IdDict()
-  tree = _setup(rule, model; cache)
+  state = _setup(rule, model; cache)
   isempty(cache) && @warn "setup found no trainable parameters in this model"
-  tree
+  state
 end
 
 # _setup is almost fmapstructure, but needs a _trainable_walk, and a cache which ignores numbers etc.
@@ -56,38 +57,38 @@ end
 ###
 ### update
 ###
-function update(tree, model, grad, higher...)
-  t′ = fmap(copy, tree; exclude = maywrite)  # walks inside Leaf
+function update(state, model, grad, higher...)
+  t′ = fmap(copy, state; exclude = maywrite)  # walks inside Leaf
   x′ = fmap(copy, model; exclude = maywrite)
   update!(t′, x′, grad, higher...)
 end
 
-function update!(tree, model, grad, higher...)
+function update!(state, model, grad, higher...)
   # First walk is to accumulate the gradient. This recursion visits every copy of
   # shared leaves, but stops when branches are absent from the gradient:
   grads = IdDict{Leaf, Any}()
-  _grads!(grads, tree, model, grad, higher...)
-  # Second walk is to update the model. The params cache indexed by (tree,x),
+  _grads!(grads, state, model, grad, higher...)
+  # Second walk is to update the model. The params cache is indexed by (state, x),
   # so that identified Leafs can tie isbits parameters, but setup won't do that for you:
-  newmodel = _update!(tree, model; grads, params = IdDict())
-  tree, newmodel  # note that tree is guaranteed to be updated. Also that it's not necc a tree.
+  newmodel = _update!(state, model; grads, params = IdDict())
+  state, newmodel  # Note that state is guaranteed to be updated. Also that it's not necessarily a tree.
 end
 
-function _update!(tree, x; grads, params)
-  haskey(params, (tree,x)) && return params[(tree,x)]
-  isbits(tree) && return x  # means () is not cached, and also (((),),)
+function _update!(state, x; grads, params)
+  haskey(params, (state, x)) && return params[(state, x)]
+  isbits(state) && return x  # means () is not cached, and also (((),),)
   x′, re = functor(x)
-  x′′ = re(valuemap((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), tree, x′))
+  x′′ = re(valuemap((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), state, x′))
   if ismutable(x′′)
-    params[(tree,x)] = x′′
+    params[(state, x)] = x′′
   else  # no ties to preserve between immutable structs, right?
     x′′
   end
 end
 function _update!(ℓ::Leaf, x; grads, params)
-  haskey(params, (ℓ,x)) && return params[(ℓ,x)]
+  haskey(params, (ℓ, x)) && return params[(ℓ, x)]
   ℓ.frozen && return x
-  params[(ℓ,x)] = if haskey(grads, ℓ)
+  params[(ℓ, x)] = if haskey(grads, ℓ)
     ℓ.state, x̄′ = apply!(ℓ.rule, ℓ.state, x, grads[ℓ]...)
     subtract!(x, x̄′)
   else
@@ -98,18 +99,21 @@ end
 subtract!(x, x̄) = maywrite(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄)
 
 _grads!(dict::IdDict, ℓ::Leaf, x, ::Zero...) = nothing
+
 function _grads!(dict::IdDict, ℓ::Leaf, x, x̄s...)
   x̄s₀ = get(dict, ℓ, map(_ -> ZeroTangent(), x̄s))
   dict[ℓ] = map(+, x̄s, x̄s₀)  # adding Zero should be free. Lazy accumulation broadcasted(+, x̄, x̄₀) also possible.
   nothing
 end
+
 _grads!(dict::IdDict, t, x, ::Zero...) = nothing
-function _grads!(dict::IdDict, tree, x, x̄s...)
+
+function _grads!(dict::IdDict, state, x, x̄s...)
   # The only reason _grads! takes model is that functor(typeof(x), base(x̄)) may differ from
-  # functor(typeof(tree), base(x̄)), for things like Transpose
+  # functor(typeof(state), base(x̄)), for things like Transpose
   x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s)
   x′, _ = functor(typeof(x), x)
-  valueforeach((tᵢ, xᵢ, x̄sᵢ...) -> _grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
+  valueforeach((tᵢ, xᵢ, x̄sᵢ...) -> _grads!(dict, tᵢ, xᵢ, x̄sᵢ...), state, x′, x̄s′...)
 end
 
 # default all rules to first order calls
@@ -142,8 +146,6 @@ For now, simply `x isa DenseArray` allowing `Array`, `CuArray`, etc.
 maywrite(::DenseArray) = true  # see https://github.com/FluxML/Optimisers.jl/issues/99 for discussion
 maywrite(_) = false
 
-@deprecate iswriteable maywrite false  # remove when releasing Optimisers@0.3
-
 """
     trainable(x::Layer) -> NamedTuple
 
@@ -175,6 +177,7 @@ end
 
 valuemap(f, x...) = map(f, x...)
 valuemap(f, x::Dict, ys...) = Dict(k => f(v, (get(y, k, nothing) for y in ys)...) for (k,v) in x)
+
 valueforeach(f, x...) = foreach(f, x...)
 valueforeach(f, x::Dict, ys...) = foreach(pairs(x)) do (k, v)
   f(v, (get(y, k, nothing) for y in ys)...)
@@ -216,8 +219,11 @@ macro lazy(ex)
 end
 
 function lazy end
+
 Broadcast.broadcasted(::typeof(lazy), x) = Lazy(x)
+
 struct Lazy{T}; bc::T; end
+
 Broadcast.materialize(x::Lazy) = Broadcast.instantiate(x.bc)
 
 onevalue(λ::T, x::AbstractArray{T}) where T = map(_ -> λ, x)
diff --git a/src/rules.jl b/src/rules.jl
index 57f2d752..898f563c 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -1,11 +1,3 @@
-@deprecate ADAM Adam
-@deprecate NADAM NAdam
-@deprecate ADAMW AdamW
-@deprecate RADAM RAdam
-@deprecate OADAM OAdam
-@deprecate ADAGrad AdaGrad
-@deprecate ADADelta AdaDelta
-
 """
     Descent(η = 1f-1)
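
For reference, not part of the patch: a minimal usage sketch of the renamed API, under the assumption of a toy NamedTuple "model", an arbitrary sum-of-squares loss, and Zygote as one possible source of a Functors-compatible gradient. Only `setup`, `update!`, and the exported `Adam` come from Optimisers itself.

    using Optimisers, Zygote

    model = (weight = rand(3, 3), bias = rand(3))      # illustrative model: any Functors-compatible structure
    opt_state = Optimisers.setup(Adam(1e-3), model)    # `Adam`, not the now-deprecated `ADAM`

    # Zygote returns a NamedTuple gradient matching the model's structure.
    grad = Zygote.gradient(m -> sum(abs2, m.weight * m.bias), model)[1]

    # `update!` mutates arrays where it can, but as the comment in `update!` above says,
    # only the returned state is guaranteed to be current, so rebind both results:
    opt_state, model = Optimisers.update!(opt_state, model, grad)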