In Flux, we typically apply `@functor` to a type for two purposes:
- recursively traversing structs and mapping their leaves, as done by `gpu`;
- collecting parameters into a `Zygote.Params` for gradient calculation (this is done by `Flux.params(model)`).

When we want to distinguish the two behaviors, we use `Flux.trainable` to restrict the parameter collection.
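To make the first behavior concrete, here is a minimal self-contained sketch (the `Affine` type and the `Float32` conversion are just for illustration) of how `fmap` from Functors.jl traverses a functor-ed struct and maps its leaves:

```julia
using Functors  # provides @functor and fmap (Flux builds on these)

struct Affine
    W::Array
    b::Array
end
@functor Affine

m = Affine(rand(2, 2), rand(2))

# fmap walks the fields registered by @functor and applies f to each leaf;
# this is the mechanism gpu uses to move every array of a model to the device.
m32 = fmap(x -> Float32.(x), m)
typeof(m32.W)  # Matrix{Float32}
```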
This is an example of the second behavior:
```julia
using Flux, Zygote
using Flux: @functor

struct B
    b1::Array
    b2::Array
end
@functor B

struct A
    a1::Array
    eps::Number
    b::B
end
@functor A
Flux.trainable(a::A) = (a.a1,)

a = A(rand(3), 0.1, B(rand(2), rand(2)))

Flux.params(a)
# Params([[0.2755365528802143, 0.7419122552485184, 0.048976872406773175]])

loss(a) = a.eps + sum(a.a1) + sum(a.b.b1)
```
Now, when one computes the gradient in the implicit form, supposedly only the gradient with respect to `a.a1` should be computed. This currently appears not to be exactly true: every gradient seems to be computed, but at least only the one with respect to `a.a1` is exposed:
```julia
julia> g = gradient(() -> loss(a), Flux.params(a))
Grads(...)

julia> g[a.a1]
3-element Fill{Float64}: entries equal to 1.0

julia> g[a.b.b1]
ERROR: KeyError: key [0.7037661100448469, 0.34941543792301455] not found
Stacktrace:
 [1] getindex
   @ ./iddict.jl:93 [inlined]
 [2] getindex(gs::Zygote.Grads, x::Vector{Float64})
   @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:279
 [3] top-level scope
   @ REPL[42]:1
 [4] top-level scope
   @ ~/.julia/packages/CUDA/lwSps/src/initialization.jl:52

julia> g = gradient(() -> loss(a), Flux.params(a)).grads
IdDict{Any, Any} with 2 entries:
  [0.275537, 0.741912, 0.0489769] => 3-element Fill{Float64}: entries equal to 1.0
  :(Main.a) => (a1 = nothing, eps = 1.0, b = (b1 = 2-element Fill{Float64}: entries equal to 1.0, b2 = nothing))
```
With the explicit gradient, instead, everything is computed and exposed:
```julia
julia> gradient(a -> loss(a), a)
((a1 = 3-element Fill{Float64}: entries equal to 1.0, eps = 1.0, b = (b1 = 2-element Fill{Float64}: entries equal to 1.0, b2 = nothing)),)
```
This is bad, since we would like to feed the result to an `update!` function, and it is also inefficient. How do we tell Zygote to drop some parts of the model from the gradient computation? I would like the following:
```julia
julia> gradient(a -> loss(a), a)
((a1 = 3-element Fill{Float64}: entries equal to 1.0),)
```
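As a partial manual workaround today, one can block gradients inside the loss itself. This sketch assumes `Zygote.dropgrad`, which treats its argument as a constant, but it has to be applied at every use site and ignores the `trainable` information:

```julia
# Sketch of a manual workaround: dropgrad blocks the gradient of its
# argument, so nothing flows back into a.b here.
loss2(a) = a.eps + sum(a.a1) + sum(Zygote.dropgrad(a.b.b1))

gradient(loss2, a)
# expected: ((a1 = 3-element Fill{Float64}: entries equal to 1.0, eps = 1.0, b = nothing),)
```

This still differentiates `eps`, and it requires editing the loss rather than the model, so it is not a real solution.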
I see two possibilities:
- we make `gradient` aware of `@functor`/`trainable`;
- we pass a keyword argument to `gradient` for the gradient masking (see the sketch after this list).
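For illustration, here is a minimal sketch of the masking semantics I have in mind, written as a post-hoc filter over the explicit gradient. The `mask` helper is hypothetical (not part of Flux or Zygote), and since it runs after the fact it only hides entries; it does not save any computation:

```julia
# Hypothetical helper: keep a gradient entry only if the corresponding
# field value is among the objects returned by Flux.trainable(model).
function mask(grad::NamedTuple, model)
    fields = fieldnames(typeof(model))
    tr = Flux.trainable(model)
    vals = map(fields) do f
        val = getfield(model, f)
        g = getfield(grad, f)
        if !any(t -> t === val, tr)
            nothing              # non-trainable field: drop its gradient
        elseif g isa NamedTuple
            mask(g, val)         # trainable sub-struct: recurse
        else
            g                    # trainable leaf: keep its gradient
        end
    end
    return NamedTuple{fields}(vals)
end

g, = gradient(loss, a)
mask(g, a)
# expected: (a1 = 3-element Fill{Float64}: entries equal to 1.0, eps = nothing, b = nothing)
```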