Commit 4331dc8

freezingdocs
1 parent 9061b79 commit 4331dc8

1 file changed: +131 -0 lines changed

Diff for: docs/src/tutorials/misc-model-tweaking.md

# Choosing differentiable/GPU parts of the model
!!! note
    This tutorial covers a few somewhat disconnected topics about customising your
    models even further. It assumes familiarity with [`Flux.@layer`](@ref),
    [`Flux.@functor`](@ref), [`freeze!`](@ref Flux.freeze!) and other Flux basics.

Flux provides several ways of freezing parameters: excluding them from backpropagation
entirely, or marking custom struct fields so that they are not moved to the GPU
([`Functors.@functor`](@ref)) and are hence excluded from training. The following
subsections should make it clear which approach suits your needs best.

## On-the-fly freezing per model instance
Perhaps you'd like to freeze some of the weights of the model (even
mid-training), and Flux accomplishes this through [`freeze!`](@ref Flux.freeze!) and `thaw!`.

```julia
m = Chain(
      Dense(784 => 64, relu), # freeze this one
      Dense(64 => 64, relu),
      Dense(64 => 10)
    )
opt_state = Flux.setup(Momentum(), m);

# Freeze some layers right away
Flux.freeze!(opt_state.layers[1])

for data in train_set
  input, label = data

  # Some params could be frozen during the training:
  Flux.freeze!(opt_state.layers[2])

  grads = Flux.gradient(m) do m
    result = m(input)
    loss(result, label)
  end
  Flux.update!(opt_state, m, grads[1])

  # Optionally unfreeze the params later
  Flux.thaw!(opt_state.layers[1])
end
```

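`freeze!` and `thaw!` also accept the entire optimiser state (or any sub-tree of it), which is a quick way to stop or resume training of everything at once. A minimal sketch, reusing the `opt_state` from above:

```julia
Flux.freeze!(opt_state)  # no parameter will be changed by `update!`
Flux.thaw!(opt_state)    # undo: every parameter is trainable again
```
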
## Static freezing per model definition
Sometimes some parts of the model ([`Flux.@layer`](@ref)) needn't be trained at all, but those
parameters still need to reside on the GPU, since they are still used in the forward
and/or backward pass.
```julia
struct MaskedLayer{T}
  chain::Chain
  mask::T
end
Flux.@layer MaskedLayer trainable=(chain,)
# The `mask` field will not be updated in the training loop

function (m::MaskedLayer)(x)
  # The `mask` field will still move to the GPU for efficient operations:
  return m.chain(x) + x + m.mask
end

model = MaskedLayer(...) # this model will not have the `mask` field trained
```
Note that this approach permanently excludes the chosen fields from training; it cannot be
changed on the fly like the previous method.

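To check the effect, you can ask Flux which fields it considers trainable. The snippet below is only a sketch, with made-up layer sizes and mask values:

```julia
mask  = rand(Float32, 8)                               # illustrative fixed mask
model = MaskedLayer(Chain(Dense(8 => 8, relu)), mask)

Flux.trainable(model)                   # (chain = ...,) -- `mask` does not appear
opt_state = Flux.setup(Adam(), model);  # optimiser state is only created for `chain`
```
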
## Excluding from model definition
Sometimes some parameters aren't just "not trainable": they shouldn't even be transferred to
the GPU (or be part of the functor at all). All scalar fields behave like this by default, so
things like learning-rate multipliers are neither trainable nor moved to the GPU.
```julia
using Statistics: mean  # `mean` is used in the forward pass below

struct CustomLayer{T, F}
  chain::T
  activation_results::Vector{F}
  lr_multiplier::Float32
end
Flux.@functor CustomLayer (chain,)  # Explicitly leaving out `activation_results`

function (m::CustomLayer)(x)
  result = m.chain(x) + x

  # `activation_results` is not part of the GPU loop, hence we can do
  # things like `push!`:
  push!(m.activation_results, mean(result))
  return result
end
```
See more about this in [`Flux.@functor`](@ref).

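As a quick check, functor-aware transformations such as `gpu` or `f64` only traverse the `chain` field and leave the rest of the struct alone. A minimal sketch with made-up sizes:

```julia
model = CustomLayer(Chain(Dense(4 => 4, relu)), Float32[], 0.1f0)

model64 = f64(model)                  # only `chain` is converted to Float64
typeof(model64.chain[1].weight)       # Matrix{Float64}
typeof(model64.activation_results)    # still Vector{Float32}: not a functor child
model64.lr_multiplier                 # 0.1f0, untouched scalar field
```
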
## Freezing Layer Parameters (deprecated)

When it is desired not to include all of the model's parameters (e.g. for transfer learning), we can simply not pass those layers into our call to `params`.

!!! compat "Flux ≤ 0.14"
    The mechanism described here is for Flux's old "implicit" training style.
    When upgrading to Flux 0.15, it should be replaced by [`freeze!`](@ref Flux.freeze!) and `thaw!`.

Consider a simple multi-layer perceptron model where we want to avoid optimising the first two `Dense` layers. We can obtain
this using the slicing features `Chain` provides:

```julia
m = Chain(
      Dense(784 => 64, relu),
      Dense(64 => 64, relu),
      Dense(64 => 10)
    );

ps = Flux.params(m[3:end])
```

The `Zygote.Params` object `ps` now holds a reference to only the parameters of the layers passed to it.

During training, gradients will only be computed for (and applied to) the last `Dense` layer, so only that layer's parameters will be changed.

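For concreteness, a single implicit-style update with this `ps` might look like the sketch below, where `x`, `y`, and `loss` are assumed to be defined elsewhere:

```julia
opt = Descent(0.01)                          # old-style (implicit) optimiser

gs = Flux.gradient(() -> loss(m(x), y), ps)  # gradients only for the parameters in `ps`
Flux.update!(opt, ps, gs)                    # hence only the last `Dense` layer is updated
```
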
`Flux.params` also takes multiple inputs to make it easy to collect parameters from heterogeneous models with a single call. A simple demonstration would be if we wanted to omit optimising the second `Dense` layer in the previous example. It would look something like this:

```julia
Flux.params(m[1], m[3:end])
```

Sometimes, more fine-grained control is needed.
We can freeze a specific parameter of a specific layer that has already entered a `Params` object `ps`
by simply deleting it from `ps`:

```julia
ps = Flux.params(m)
delete!(ps, m[2].bias)
```
