
how to selectively take structural gradient #1042

@CarloLucibello

Description

In Flux, we typically apply @functor to a type for two purposes:

  1. recursively traversing structs and mapping leaves, as done by gpu;
  2. collecting parameters in a Zygote.Params for gradient calculation (this is what Flux.params(model) does).

When we want to distinguish the two behaviors, we use Flux.trainable to restrict the parameter collection. Here is an example:
using Flux, Zygote
using Flux: @functor

struct B
   b1::Array
   b2::Array
end
@functor B

struct A
   a1::Array
   eps::Number
   b::B
end
@functor A
Flux.trainable(a::A) = (a.a1,)  # only a1 should be collected as a parameter

a = A(rand(3), 0.1, B(rand(2), rand(2)))

Flux.params(a)
# Params([[0.2755365528802143, 0.7419122552485184, 0.048976872406773175]])

loss(a) = a.eps + sum(a.a1) + sum(a.b.b1)

Now, when one computes the gradient in the implicit form, supposedly only the gradient with respect to
a.a1 should be computed. Currently this appears to not be exactly true: every gradient seems to be computed, but at least only the one with respect to a.a1 is exposed:

julia> g = gradient(() -> loss(a), Flux.params(a))
Grads(...)

julia> g[a.a1]
3-element Fill{Float64}: entries equal to 1.0

julia> g[a.b.b1]
ERROR: KeyError: key [0.7037661100448469, 0.34941543792301455] not found
Stacktrace:
 [1] getindex
   @ ./iddict.jl:93 [inlined]
 [2] getindex(gs::Zygote.Grads, x::Vector{Float64})
   @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:279
 [3] top-level scope
   @ REPL[42]:1
 [4] top-level scope
   @ ~/.julia/packages/CUDA/lwSps/src/initialization.jl:52

Indeed, inspecting the underlying grads dict shows that the full structural gradient was computed; it is just stored under the captured global :(Main.a):

julia> g = gradient(() -> loss(a), Flux.params(a)).grads
IdDict{Any, Any} with 2 entries:
  [0.275537, 0.741912, 0.0489769] => 3-element Fill{Float64}: entries equal to 1.0
  :(Main.a)                       => (a1 = nothing, eps = 1.0, b = (b1 = 2-element Fill{Float64}: entries equal to 1.0, b2 = nothing))
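
Consistent with the IdDict above, an array is retrievable from grads only if it was collected into Params; the gradient of a.b.b1 is buried inside the :(Main.a) entry. Continuing the session (a sketch, relying only on the two entries shown above):

haskey(g.grads, a.a1)    # true: a.a1 was collected by Flux.params
haskey(g.grads, a.b.b1)  # false: its gradient only lives inside the :(Main.a) entry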

With the explicit gradient, instead, everything is computed and exposed:

julia> gradient(a -> loss(a), a)
((a1 = 3-element Fill{Float64}: entries equal to 1.0, eps = 1.0, b = (b1 = 2-element Fill{Float64}: entries equal to 1.0, b2 = nothing)),)

This is bad, since we would like to feed the result to an update! function, and it is also inefficient. How do we tell Zygote to drop some parts of the model from the gradient computation? I would like the following:

julia> gradient(a -> loss(a), a)
((a1 = 3-element Fill{Float64}: entries equal to 1.0),)

I see two possibilities:

  • we make gradient aware of @functor/trainable (a rough post-hoc version is sketched below);
  • we pass gradient a keyword argument for masking the gradient.
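
As a rough sketch of the first option, one could filter the explicit gradient after the fact, walking the model and the gradient in parallel and dropping every child that is not listed in Flux.trainable. Note that filter_trainable below is a hypothetical helper, not an existing Flux/Zygote API, and it assumes the structural gradient is a NamedTuple mirroring the @functor children:

using Flux
using Functors: functor

# Hypothetical helper, not an existing API: recursively keep only the parts
# of an explicit gradient whose model counterpart appears in Flux.trainable.
function filter_trainable(model, grad)
    grad === nothing && return nothing
    children, _ = functor(model)          # NamedTuple of @functor children
    isempty(children) && return grad      # leaf node (Array, Number, ...)
    keep = Flux.trainable(model)          # the children we want gradients for
    vals = map(keys(children)) do k
        child = getfield(model, k)
        any(t -> t === child, keep) ? filter_trainable(child, getfield(grad, k)) : nothing
    end
    return NamedTuple{keys(children)}(vals)
end

g, = gradient(loss, a)
filter_trainable(a, g)
# (a1 = Fill(1.0, 3), eps = nothing, b = nothing)

The keyword-argument option could reuse the same walk to build a mask before differentiating; either way, the remaining nothing entries could simply be skipped by a structural update!.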
