
Commit 7e0262c

Authored by bors[bot], darsnack, and lorenzoh
Merge #1305

1305: Updates to outdims r=CarloLucibello a=darsnack

Since #1253 stalled, I tried committing to the author's branch, but I have not received a response. So, I am creating a new PR with the following changes from the previous one:

- `outdims` for generic functions
- Size checking for `outdims(::Dense, isize)`

I also added the following additional changes:

- Removed type signature restrictions on `outdims` for generic functions
- Added `outdims` for normalization layers
  - This is helpful since `BatchNorm` etc. show up frequently in a chain or array of layers when model building
  - Right now there is a method error
  - Generic functions would address this, but I think we should avoid actually evaluating the function as much as possible
- Updated docs for `outdims` changes

### PR Checklist

- [x] Tests are added
- [ ] Entry in NEWS.md
- [x] Documentation, if applicable
- [x] Final review from `@dhairyagandhi96` (for API changes).

Co-authored-by: Kyle Daruwalla <[email protected]>
Co-authored-by: lorenzoh <[email protected]>
2 parents 2039f3f + 438db24 commit 7e0262c

12 files changed: +321 −168 lines
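One motivating case from the PR description, a normalization layer inside a chain, now works. A sketch assuming this commit (the exact REPL output is illustrative):

```julia
julia> using Flux

julia> m = Chain(Dense(10, 8), BatchNorm(8), Dense(8, 2));

julia> Flux.outputsize(m, (10,); padbatch=true)  # previously a MethodError with outdims
(2, 1)
```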

docs/src/models/basics.md (+1 −22)

@@ -218,25 +218,4 @@ Flux.@functor Affine
 
 This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md).
 
-For some more helpful tricks, including parameter freezing, please checkout the [advanced usage guide](advanced.md).
-
-## Utility functions
-
-Flux provides some utility functions to help you generate models in an automated fashion.
-
-`outdims` enables you to calculate the spatial output dimensions of layers like `Conv` when applied to input images of a given size.
-Currently limited to the following layers:
-- `Chain`
-- `Dense`
-- `Conv`
-- `Diagonal`
-- `Maxout`
-- `ConvTranspose`
-- `DepthwiseConv`
-- `CrossCor`
-- `MaxPool`
-- `MeanPool`
-
-```@docs
-Flux.outdims
-```
+For some more helpful tricks, including parameter freezing, please check out the [advanced usage guide](advanced.md).

docs/src/utilities.md (+48)

@@ -35,6 +35,54 @@ Flux.glorot_uniform
 Flux.glorot_normal
 ```
 
+## Model Building
+
+Flux provides some utility functions to help you generate models in an automated fashion.
+
+[`outputsize`](@ref) enables you to calculate the output sizes of layers like [`Conv`](@ref)
+when applied to input samples of a given size. This is achieved by passing a "dummy" array into
+the model that preserves size information without running any computation.
+`outputsize(f, inputsize)` works for all layers (including custom layers) out of the box.
+By default, `inputsize` is expected to include the batch dimension, but you can omit it
+and call `outputsize(f, inputsize; padbatch=true)`, which assumes a batch size of one.
+
+Using this utility function lets you automate model building for various inputs like so:
+```julia
+"""
+    make_model(width, height, inchannels, nclasses;
+               layer_config = [16, 16, 32, 32, 64, 64])
+
+Create a CNN for a given set of configuration parameters.
+
+# Arguments
+- `width`: the input image width
+- `height`: the input image height
+- `inchannels`: the number of channels in the input image
+- `nclasses`: the number of output classes
+- `layer_config`: a vector of the number of filters for each conv layer
+"""
+function make_model(width, height, inchannels, nclasses;
+                    layer_config = [16, 16, 32, 32, 64, 64])
+  # construct a vector of conv layers programmatically
+  conv_layers = [Conv((3, 3), inchannels => layer_config[1])]
+  for (infilters, outfilters) in zip(layer_config, layer_config[2:end])
+    push!(conv_layers, Conv((3, 3), infilters => outfilters))
+  end
+
+  # compute the output dimensions for the conv layers
+  # use padbatch=true to set the batch dimension to 1
+  conv_outsize = Flux.outputsize(conv_layers, (width, height, inchannels); padbatch=true)
+
+  # the input dimension to Dense is programmatically calculated from
+  # width, height, and inchannels; flatten so Dense receives a matrix
+  return Chain(conv_layers..., Flux.flatten, Dense(prod(conv_outsize), nclasses))
+end
+```
+
+```@docs
+Flux.outputsize
+```
+
 ## Model Abstraction
 
 ```@docs
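A usage sketch for the docs example above (assumes this commit, with `make_model` defined as in the diff; sizes worked out by hand):

```julia
julia> using Flux

julia> m = make_model(32, 32, 3, 10);   # 32×32 RGB inputs, 10 classes

julia> size(m(rand(Float32, 32, 32, 3, 7)))  # Dense input size was derived via outputsize
(10, 7)
```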

src/Flux.jl (+2)

@@ -41,6 +41,8 @@ include("layers/conv.jl")
 include("layers/recurrent.jl")
 include("layers/normalise.jl")
 
+include("outputsize.jl")
+
 include("data/Data.jl")
 
 include("losses/Losses.jl")

src/deprecations.jl (+1)

@@ -3,3 +3,4 @@
 @deprecate InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum) InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing)
 @deprecate BatchNorm(λ, β, γ, μ, σ², ϵ, momentum) BatchNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing)
 @deprecate GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum) GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum, nothing)
+@deprecate outdims(f, inputsize) outputsize(f, inputsize)
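Existing `outdims` callers keep working: the call emits a deprecation warning and forwards to the new function. A sketch assuming this commit:

```julia
julia> using Flux

julia> Flux.outdims(Dense(10, 5), (10,))  # warns that outdims is deprecated, then forwards
(5,)
```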

src/layers/basic.jl (−42)

@@ -45,22 +45,6 @@ function Base.show(io::IO, c::Chain)
   print(io, ")")
 end
 
-"""
-    outdims(c::Chain, isize)
-
-Calculate the output dimensions given the input dimensions, `isize`.
-
-```jldoctest
-julia> using Flux: outdims
-
-julia> m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32));
-
-julia> outdims(m, (10, 10)) == (6, 6)
-true
-```
-"""
-outdims(c::Chain, isize) = foldr(outdims, reverse(c.layers), init = isize)
-
 # This is a temporary and naive implementation
 # it might be replaced in the future for better performance
 # see issue https://github.com/FluxML/Flux.jl/issues/702

@@ -158,28 +142,6 @@ end
 (a::Dense{<:Any,W})(x::AbstractArray{<:AbstractFloat}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
   a(T.(x))
 
-"""
-    outdims(l::Dense, isize)
-
-Calculate the output dimensions given the input dimensions, `isize`.
-
-```jldoctest
-julia> using Flux: outdims
-
-julia> m = Dense(10, 5);
-
-julia> outdims(m, (10, 100)) == (5,)
-true
-
-julia> outdims(m, (10,)) == (5,)
-true
-```
-"""
-function outdims(l::Dense, isize)
-  first(isize) == size(l.W, 2) || throw(DimensionMismatch("input size should equal to ($(size(l.W, 2)),), got $isize"))
-  return (size(l.W, 1),)
-end
-
 """
     Diagonal(in::Integer)

@@ -209,8 +171,6 @@ function Base.show(io::IO, l::Diagonal)
   print(io, "Diagonal(", length(l.α), ")")
 end
 
-outdims(l::Diagonal, isize) = (length(l.α),)
-
 """
     Maxout(over)

@@ -254,8 +214,6 @@ function (mo::Maxout)(input::AbstractArray)
   mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
 end
 
-outdims(l::Maxout, isize) = outdims(first(l.over), isize)
-
 """
     SkipConnection(layer, connection)
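For comparison, the removed doctests map directly onto the new API. Note that `outputsize` reports the full size, including channel and batch dimensions (a sketch run against this commit):

```julia
julia> using Flux: outputsize

julia> m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32));

julia> outputsize(m, (10, 10, 3); padbatch=true)  # outdims(m, (10, 10)) gave (6, 6)
(6, 6, 32, 1)

julia> outputsize(Dense(10, 5), (10, 100))        # outdims(m, (10, 100)) gave (5,)
(5, 100)
```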

src/layers/conv.jl (−29)

@@ -3,8 +3,6 @@ using NNlib: conv, ∇conv_data, depthwiseconv, output_size
 # pad dims of x with dims of y until ndims(x) == ndims(y)
 _paddims(x::Tuple, y::Tuple) = (x..., y[(end - (length(y) - length(x) - 1)):end]...)
 
-_convtransoutdims(isize, ksize, ssize, dsize, pad) = (isize .- 1).*ssize .+ 1 .+ (ksize .- 1).*dsize .- (pad[1:2:end] .+ pad[2:2:end])
-
 expand(N, i::Tuple) = i
 expand(N, i::Integer) = ntuple(_ -> i, N)

@@ -188,21 +186,6 @@ end
 (a::Conv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
   a(T.(x))
 
-"""
-    outdims(l::Conv, isize::Tuple)
-
-Calculate the output dimensions given the input dimensions `isize`.
-Batch size and channel size are ignored as per [NNlib.jl](https://github.com/FluxML/NNlib.jl).
-
-```julia
-m = Conv((3, 3), 3 => 16)
-outdims(m, (10, 10)) == (8, 8)
-outdims(m, (10, 10, 1, 3)) == (8, 8)
-```
-"""
-outdims(l::Conv, isize) =
-  output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))
-
 """
     ConvTranspose(filter, in => out, σ=identity; stride=1, pad=0, dilation=1)

@@ -311,8 +294,6 @@ end
 (a::ConvTranspose{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
   a(T.(x))
 
-outdims(l::ConvTranspose{N}, isize) where N = _convtransoutdims(isize[1:2], size(l.weight)[1:N], l.stride, l.dilation, l.pad)
-
 function calc_padding(::Type{ConvTranspose}, pad::SamePad, k::NTuple{N,T}, dilation, stride) where {N,T}
   calc_padding(Conv, pad, k .- stride .+ 1, dilation, stride)
 end

@@ -425,9 +406,6 @@ end
 (a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
   a(T.(x))
 
-outdims(l::DepthwiseConv, isize) =
-  output_size(DepthwiseConvDims(_paddims(isize, (1, 1, size(l.weight)[end], 1)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))
-
 """
     CrossCor(filter, in => out, σ=identity; stride=1, pad=0, dilation=1)

@@ -521,9 +499,6 @@ end
 (a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
   a(T.(x))
 
-outdims(l::CrossCor, isize) =
-  output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))
-
 """
     AdaptiveMaxPool(out::NTuple)

@@ -744,8 +719,6 @@ end
 _maybetuple_string(pad) = string(pad)
 _maybetuple_string(pad::Tuple) = all(==(pad[1]), pad) ? string(pad[1]) : string(pad)
 
-outdims(l::MaxPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad))
-
 """
     MeanPool(window::NTuple; pad=0, stride=window)

@@ -798,5 +771,3 @@ function Base.show(io::IO, m::MeanPool)
   m.stride == m.k || print(io, ", stride=", _maybetuple_string(m.stride))
   print(io, ")")
 end
-
-outdims(l::MeanPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad))
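The removed per-layer formulas are now covered by running the layer itself on `Nil` arrays. For instance (a sketch against this commit; the new module defines `isless` and `typemin` for `Nil` precisely so pooling works):

```julia
julia> using Flux: outputsize

julia> outputsize(MaxPool((2, 2)), (8, 8, 3, 1))         # pooling via Nil comparisons
(4, 4, 3, 1)

julia> outputsize(CrossCor((3, 3), 3 => 16), (10, 10, 3, 1))  # via the patched NNlib.conv
(8, 8, 16, 1)
```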

src/layers/normalise.jl (+1 −1)

@@ -420,4 +420,4 @@ function Base.show(io::IO, l::GroupNorm)
   print(io, "GroupNorm($(join(size(l.β), ", "))")
   (l.λ == identity) || print(io, ", λ = $(l.λ)")
   print(io, ")")
-end
\ No newline at end of file
+end

src/outputsize.jl (+128, new file)

@@ -0,0 +1,128 @@
+module NilNumber
+
+using NNlib
+
+"""
+    Nil <: Number
+
+Nil is a singleton type with a single instance `nil`.
+Unlike `Nothing` and `Missing` it subtypes `Number`.
+"""
+struct Nil <: Number end
+
+const nil = Nil()
+
+Nil(::T) where T<:Number = nil
+(::Type{T})(::Nil) where T<:Number = nil
+Base.convert(::Type{Nil}, ::Number) = nil
+
+Base.float(::Type{Nil}) = Nil
+
+for f in [:copy, :zero, :one, :oneunit,
+          :+, :-, :abs, :abs2, :inv,
+          :exp, :log, :log1p, :log2, :log10,
+          :sqrt, :tanh, :conj]
+  @eval Base.$f(::Nil) = nil
+end
+
+for f in [:+, :-, :*, :/, :^, :mod, :div, :rem]
+  @eval Base.$f(::Nil, ::Nil) = nil
+end
+
+Base.isless(::Nil, ::Nil) = true
+Base.isless(::Nil, ::Number) = true
+Base.isless(::Number, ::Nil) = true
+
+Base.isnan(::Nil) = false
+
+Base.typemin(::Type{Nil}) = nil
+Base.typemax(::Type{Nil}) = nil
+
+Base.promote_rule(x::Type{Nil}, y::Type{<:Number}) = Nil
+
+end # module
+
+using .NilNumber: Nil, nil
+
+"""
+    outputsize(m, inputsize::Tuple; padbatch=false)
+
+Calculate the output size of model `m` given the input size.
+Obeys `outputsize(m, size(x)) == size(m(x))` for valid input `x`.
+Keyword `padbatch=true` is equivalent to using `(inputsize..., 1)`, and
+returns the final size including this extra batch dimension.
+
+This should be faster than calling `size(m(x))`. It uses a trivial number type,
+and thus should work out of the box for custom layers.
+
+If `m` is a `Tuple` or `Vector`, its elements are applied in sequence, like `Chain(m...)`.
+
+# Examples
+```jldoctest
+julia> using Flux: outputsize
+
+julia> outputsize(Dense(10, 4), (10,); padbatch=true)
+(4, 1)
+
+julia> m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32));
+
+julia> m(randn(Float32, 10, 10, 3, 64)) |> size
+(6, 6, 32, 64)
+
+julia> outputsize(m, (10, 10, 3); padbatch=true)
+(6, 6, 32, 1)
+
+julia> outputsize(m, (10, 10, 3, 64))
+(6, 6, 32, 64)
+
+julia> try outputsize(m, (10, 10, 7, 64)) catch e println(e) end
+DimensionMismatch("Input channels must match! (7 vs. 3)")
+
+julia> outputsize([Dense(10, 4), Dense(4, 2)], (10, 1))
+(2, 1)
+
+julia> using LinearAlgebra: norm
+
+julia> f(x) = x ./ norm.(eachcol(x));
+
+julia> outputsize(f, (10, 1)) # manually specify batch size as 1
+(10, 1)
+
+julia> outputsize(f, (10,); padbatch=true) # no need to mention batch size
+(10, 1)
+```
+"""
+function outputsize(m, inputsize::Tuple; padbatch=false)
+  inputsize = padbatch ? (inputsize..., 1) : inputsize
+
+  return size(m(fill(nil, inputsize)))
+end
+
+## make tuples and vectors be like Chains
+
+outputsize(m::Tuple, inputsize::Tuple; padbatch=false) = outputsize(Chain(m...), inputsize; padbatch=padbatch)
+outputsize(m::AbstractVector, inputsize::Tuple; padbatch=false) = outputsize(Chain(m...), inputsize; padbatch=padbatch)
+
+## bypass statistics in normalization layers
+
+for layer in (:LayerNorm, :BatchNorm, :InstanceNorm, :GroupNorm)
+  @eval (l::$layer)(x::AbstractArray{Nil}) = x
+end
+
+## fixes for layers that don't work out of the box
+
+for (fn, Dims) in ((:conv, DenseConvDims), (:depthwiseconv, DepthwiseConvDims))
+  @eval begin
+    function NNlib.$fn(a::AbstractArray{Nil}, b::AbstractArray{Nil}, dims::$Dims)
+      fill(nil, NNlib.output_size(dims)..., NNlib.channels_out(dims), size(a)[end])
+    end
+
+    function NNlib.$fn(a::AbstractArray{<:Real}, b::AbstractArray{Nil}, dims::$Dims)
+      NNlib.$fn(fill(nil, size(a)), b, dims)
+    end
+
+    function NNlib.$fn(a::AbstractArray{Nil}, b::AbstractArray{<:Real}, dims::$Dims)
+      NNlib.$fn(a, fill(nil, size(b)), dims)
+    end
+  end
+end
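To make the dummy-array trick concrete, here is a small sketch of how `Nil` propagates through ordinary array code (assumes this commit; the submodule should be reachable as `Flux.NilNumber`):

```julia
using Flux
using Flux.NilNumber: nil

# every operation on nil returns nil, so only shapes survive
x = fill(nil, 10, 3)   # stands in for a 10-feature batch of 3 samples
W = fill(nil, 5, 10)
size(W * x)            # (5, 3): generic matmul runs on Nil without real arithmetic

# which is exactly what outputsize does internally:
Flux.outputsize(Dense(10, 5), (10, 3))  # == (5, 3)
```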
