Skip to content

Commit 87a0065

Browse files
bors[bot] authored (commit by Michael Abbott, @mcabbott)
Merge #1467
1467: show(::Chain) r=CarloLucibello a=mcabbott <img width="600" alt="Screenshot 2021-01-15 at 21 52 03" src="https://user-images.githubusercontent.com/32575566/104777585-2e57cf80-577c-11eb-83ba-7a4605fcdb00.png"> Not quite done, RFC? Co-authored-by: Michael Abbott <me@escbook> Co-authored-by: Michael Abbott <[email protected]>
2 parents 726bbdd + c91e731 commit 87a0065

File tree

10 files changed

+258
-52
lines changed

10 files changed

+258
-52
lines changed

docs/src/utilities.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ To change the default on an applicable layer, pass the desired function with the
2828

2929
```jldoctest; setup = :(using Flux)
3030
julia> conv = Conv((3, 3), 1 => 8, relu; init=Flux.glorot_normal)
31-
Conv((3, 3), 1=>8, relu)
31+
Conv((3, 3), 1 => 8, relu) # 80 parameters
3232
```
3333

3434
```@docs

src/Flux.jl

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ include("layers/conv.jl")
4343
include("layers/recurrent.jl")
4444
include("layers/normalise.jl")
4545
include("layers/upsample.jl")
46+
include("layers/show.jl")
4647

4748
include("outputsize.jl")
4849

src/functor.jl

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import Adapt: adapt, adapt_storage
22
using LinearAlgebra: Cholesky
33
using Zygote: IdSet
4-
import Functors: @functor, functor, fmap
5-
import Functors
4+
import Functors: Functors, @functor, functor, fmap, isleaf
65

76
trainable(m) = functor(m)[1]
87

src/layers/basic.jl

+10-4
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ The weight matrix and/or the bias vector (of length `out`) may also be provided
8989
# Examples
9090
```jldoctest
9191
julia> d = Dense(5, 2)
92-
Dense(5, 2)
92+
Dense(5, 2) # 12 parameters
9393
9494
julia> d(rand(Float32, 5, 64)) |> size
9595
(2, 64)
@@ -98,7 +98,7 @@ julia> d(rand(Float32, 5, 1, 1, 64)) |> size # treated as three batch dimension
9898
(2, 1, 1, 64)
9999
100100
julia> d1 = Dense(ones(2, 5), false, tanh) # using provided weight matrix
101-
Dense(5, 2, tanh; bias=false)
101+
Dense(5, 2, tanh; bias=false) # 10 parameters
102102
103103
julia> d1(ones(5))
104104
2-element Vector{Float64}:
@@ -395,7 +395,11 @@ julia> size(model(rand(3)))
395395
(17,)
396396
397397
julia> model = Parallel(+, Dense(10, 2), Dense(5, 2))
398-
Parallel(+, Dense(10, 2), Dense(5, 2))
398+
Parallel(
399+
+,
400+
Dense(10, 2), # 22 parameters
401+
Dense(5, 2), # 12 parameters
402+
) # Total: 4 arrays, 34 parameters, 392 bytes.
399403
400404
julia> size(model(rand(10), rand(5)))
401405
(2,)
@@ -417,8 +421,10 @@ Parallel(connection, layers...) = Parallel(connection, layers)
417421
Base.getindex(m::Parallel, i::Integer) = m.layers[i]
418422
Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i]...)
419423

424+
trainable(m::Parallel) = (m.connection, m.layers...)
425+
420426
function Base.show(io::IO, m::Parallel)
421427
print(io, "Parallel(", m.connection, ", ")
422428
join(io, m.layers, ", ")
423429
print(io, ")")
424-
end
430+
end

src/layers/conv.jl

+41-30
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ See also [`ConvTranspose`](@ref), [`DepthwiseConv`](@ref), [`CrossCor`](@ref).
6868
julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of images
6969
7070
julia> lay = Conv((5,5), 3 => 7, relu; bias=false)
71-
Conv((5, 5), 3=>7, relu)
71+
Conv((5, 5), 3 => 7, relu, bias=false) # 525 parameters
7272
7373
julia> lay(xs) |> size
7474
(96, 96, 7, 50)
@@ -99,7 +99,7 @@ end
9999
Conv(weight::AbstractArray, [bias, activation; stride, pad, dilation])
100100
101101
Constructs a convolutional layer with the given weight and bias.
102-
Accepts the same keywords (and has the same defaults) as the `Conv((4,4), 3=>7, relu)`
102+
Accepts the same keywords (and has the same defaults) as the `Conv((4,4), 3 => 7, relu)`
103103
method.
104104
105105
# Examples
@@ -109,7 +109,7 @@ julia> weight = rand(3, 4, 5);
109109
julia> bias = zeros(5);
110110
111111
julia> c1 = Conv(weight, bias, sigmoid) # expects 1 spatial dimension
112-
Conv((3,), 4=>5, σ)
112+
Conv((3,), 4 => 5, σ) # 65 parameters
113113
114114
julia> c1(randn(100, 4, 64)) |> size
115115
(98, 5, 64)
@@ -135,7 +135,7 @@ function Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity
135135
end
136136

137137
"""
138-
convfilter(filter::Tuple, in=>out)
138+
convfilter(filter::Tuple, in => out)
139139
140140
Constructs a standard convolutional weight matrix with given `filter` and
141141
channels from `in` to `out`.
@@ -160,11 +160,18 @@ end
160160

161161
function Base.show(io::IO, l::Conv)
162162
print(io, "Conv(", size(l.weight)[1:ndims(l.weight)-2])
163-
print(io, ", ", size(l.weight, ndims(l.weight)-1), "=>", size(l.weight, ndims(l.weight)))
164-
l.σ == identity || print(io, ", ", l.σ)
163+
print(io, ", ", size(l.weight, ndims(l.weight)-1), " => ", size(l.weight, ndims(l.weight)))
164+
_print_conv_opt(io, l)
165165
print(io, ")")
166166
end
167167

168+
function _print_conv_opt(io::IO, l)
169+
l.σ == identity || print(io, ", ", l.σ)
170+
all(==(0), l.pad) || print(io, ", pad=", _maybetuple_string(l.pad))
171+
all(==(1), l.stride) || print(io, ", stride=", _maybetuple_string(l.stride))
172+
all(==(1), l.dilation) || print(io, ", dilation=", _maybetuple_string(l.dilation))
173+
l.bias == Zeros() && print(io, ", bias=false")
174+
end
168175

169176
"""
170177
ConvTranspose(filter, in => out, σ=identity; stride=1, pad=0, dilation=1, [bias, init])
@@ -185,15 +192,15 @@ See also [`Conv`](@ref) for more detailed description of keywords.
185192
julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of 50 RGB images
186193
187194
julia> lay = ConvTranspose((5,5), 3 => 7, relu)
188-
ConvTranspose((5, 5), 3=>7, relu)
195+
ConvTranspose((5, 5), 3 => 7, relu) # 532 parameters
189196
190197
julia> lay(xs) |> size
191198
(104, 104, 7, 50)
192199
193-
julia> ConvTranspose((5,5), 3=>7, stride=2)(xs) |> size
200+
julia> ConvTranspose((5,5), 3 => 7, stride=2)(xs) |> size
194201
(203, 203, 7, 50)
195202
196-
julia> ConvTranspose((5,5), 3=>7, stride=3, pad=SamePad())(xs) |> size
203+
julia> ConvTranspose((5,5), 3 => 7, stride=3, pad=SamePad())(xs) |> size
197204
(300, 300, 7, 50)
198205
```
199206
"""
@@ -210,7 +217,7 @@ end
210217
ConvTranspose(weight::AbstractArray, [bias, activation; stride, pad, dilation])
211218
212219
Constructs a layer with the given weight and bias arrays.
213-
Accepts the same keywords as the `ConvTranspose((4,4), 3=>7, relu)` method.
220+
Accepts the same keywords as the `ConvTranspose((4,4), 3 => 7, relu)` method.
214221
"""
215222
function ConvTranspose(w::AbstractArray{T,N}, bias = true, σ = identity;
216223
stride = 1, pad = 0, dilation = 1) where {T,N}
@@ -256,8 +263,8 @@ end
256263

257264
function Base.show(io::IO, l::ConvTranspose)
258265
print(io, "ConvTranspose(", size(l.weight)[1:ndims(l.weight)-2])
259-
print(io, ", ", size(l.weight, ndims(l.weight)), "=>", size(l.weight, ndims(l.weight)-1))
260-
l.σ == identity || print(io, ", ", l.σ)
266+
print(io, ", ", size(l.weight, ndims(l.weight)), " => ", size(l.weight, ndims(l.weight)-1))
267+
_print_conv_opt(io, l)
261268
print(io, ")")
262269
end
263270

@@ -267,7 +274,7 @@ function calc_padding(::Type{ConvTranspose}, pad::SamePad, k::NTuple{N,T}, dilat
267274
end
268275

269276
"""
270-
DepthwiseConv(filter, in=>out, σ=identity; stride=1, pad=0, dilation=1, [bias, init])
277+
DepthwiseConv(filter, in => out, σ=identity; stride=1, pad=0, dilation=1, [bias, init])
271278
272279
Depthwise convolutional layer. `filter` is a tuple of integers
273280
specifying the size of the convolutional kernel, while
@@ -285,7 +292,7 @@ See also [`Conv`](@ref) for more detailed description of keywords.
285292
julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of 50 RGB images
286293
287294
julia> lay = DepthwiseConv((5,5), 3 => 6, relu; bias=false)
288-
DepthwiseConv((5, 5), 3=>6, relu)
295+
DepthwiseConv((5, 5), 3 => 6, relu, bias=false) # 150 parameters
289296
290297
julia> lay(xs) |> size
291298
(96, 96, 6, 50)
@@ -307,7 +314,7 @@ end
307314
DepthwiseConv(weight::AbstractArray, bias, [activation; stride, pad, dilation])
308315
309316
Constructs a layer with the given weight and bias arrays.
310-
Accepts the same keywords as the `DepthwiseConv((4,4), 3=>6, relu)` method.
317+
Accepts the same keywords as the `DepthwiseConv((4,4), 3 => 6, relu)` method.
311318
"""
312319
function DepthwiseConv(w::AbstractArray{T,N}, bias = true, σ = identity;
313320
stride = 1, pad = 0, dilation = 1) where {T,N}
@@ -328,7 +335,7 @@ end
328335
@functor DepthwiseConv
329336

330337
"""
331-
depthwiseconvfilter(filter::Tuple, in=>out)
338+
depthwiseconvfilter(filter::Tuple, in => out)
332339
333340
Constructs a depthwise convolutional weight array defined by `filter` and channels
334341
from `in` to `out`.
@@ -349,8 +356,8 @@ end
349356

350357
function Base.show(io::IO, l::DepthwiseConv)
351358
print(io, "DepthwiseConv(", size(l.weight)[1:end-2])
352-
print(io, ", ", size(l.weight)[end], "=>", prod(size(l.weight)[end-1:end]))
353-
l.σ == identity || print(io, ", ", l.σ)
359+
print(io, ", ", size(l.weight)[end], " => ", prod(size(l.weight)[end-1:end]))
360+
_print_conv_opt(io, l)
354361
print(io, ")")
355362
end
356363

@@ -373,12 +380,12 @@ See also [`Conv`](@ref) for more detailed description of keywords.
373380
julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of 50 RGB images
374381
375382
julia> lay = CrossCor((5,5), 3 => 6, relu; bias=false)
376-
CrossCor((5, 5), 3=>6, relu)
383+
CrossCor((5, 5), 3 => 6, relu, bias=false) # 450 parameters
377384
378385
julia> lay(xs) |> size
379386
(96, 96, 6, 50)
380387
381-
julia> CrossCor((5,5), 3=>7, stride=3, pad=(2,0))(xs) |> size
388+
julia> CrossCor((5,5), 3 => 7, stride=3, pad=(2,0))(xs) |> size
382389
(34, 32, 7, 50)
383390
```
384391
"""
@@ -395,7 +402,7 @@ end
395402
CrossCor(weight::AbstractArray, [bias, activation; stride, pad, dilation])
396403
397404
Constructs a layer with the given weight and bias arrays.
398-
Accepts the same keywords as the `CrossCor((4,4), 3=>7, relu)` method.
405+
Accepts the same keywords as the `CrossCor((4,4), 3 => 7, relu)` method.
399406
"""
400407
function CrossCor(w::AbstractArray{T,N}, bias = true, σ = identity;
401408
stride = 1, pad = 0, dilation = 1) where {T,N}
@@ -430,8 +437,8 @@ end
430437

431438
function Base.show(io::IO, l::CrossCor)
432439
print(io, "CrossCor(", size(l.weight)[1:ndims(l.weight)-2])
433-
print(io, ", ", size(l.weight, ndims(l.weight)-1), "=>", size(l.weight, ndims(l.weight)))
434-
l.σ == identity || print(io, ", ", l.σ)
440+
print(io, ", ", size(l.weight, ndims(l.weight)-1), " => ", size(l.weight, ndims(l.weight)))
441+
_print_conv_opt(io, l)
435442
print(io, ")")
436443
end
437444

@@ -530,8 +537,7 @@ See also [`MaxPool`](@ref), [`GlobalMeanPool`](@ref).
530537
```jldoctest
531538
julia> xs = rand(Float32, 100, 100, 3, 50);
532539
533-
julia> m = Chain(Conv((3,3), 3=>7), GlobalMaxPool())
534-
Chain(Conv((3, 3), 3=>7), GlobalMaxPool())
540+
julia> m = Chain(Conv((3,3), 3 => 7), GlobalMaxPool());
535541
536542
julia> m(xs) |> size
537543
(1, 1, 7, 50)
@@ -568,8 +574,7 @@ by performing mean pooling on the complete (w,h)-shaped feature maps.
568574
```jldoctest
569575
julia> xs = rand(Float32, 100, 100, 3, 50);
570576
571-
julia> m = Chain(Conv((3,3), 3=>7), GlobalMeanPool())
572-
Chain(Conv((3, 3), 3=>7), GlobalMeanPool())
577+
julia> m = Chain(Conv((3,3), 3 => 7), GlobalMeanPool());
573578
574579
julia> m(xs) |> size
575580
(1, 1, 7, 50)
@@ -612,8 +617,11 @@ See also [`Conv`](@ref), [`MeanPool`](@ref), [`AdaptiveMaxPool`](@ref), [`Global
612617
```jldoctest
613618
julia> xs = rand(Float32, 100, 100, 3, 50); # batch of 50 RGB images
614619
615-
julia> m = Chain(Conv((5, 5), 3=>7, pad=SamePad()), MaxPool((5, 5), pad=SamePad()))
616-
Chain(Conv((5, 5), 3=>7), MaxPool((5, 5), pad=2))
620+
julia> m = Chain(Conv((5, 5), 3 => 7, pad=SamePad()), MaxPool((5, 5), pad=SamePad()))
621+
Chain(
622+
Conv((5, 5), 3 => 7, pad=2), # 532 parameters
623+
MaxPool((5, 5), pad=2),
624+
)
617625
618626
julia> m[1](xs) |> size
619627
(100, 100, 7, 50)
@@ -675,7 +683,10 @@ See also [`Conv`](@ref), [`MaxPool`](@ref), [`AdaptiveMeanPool`](@ref).
675683
julia> xs = rand(Float32, 100, 100, 3, 50);
676684
677685
julia> m = Chain(Conv((5,5), 3 => 7), MeanPool((5,5), pad=SamePad()))
678-
Chain(Conv((5, 5), 3=>7), MeanPool((5, 5), pad=2))
686+
Chain(
687+
Conv((5, 5), 3 => 7), # 532 parameters
688+
MeanPool((5, 5), pad=2),
689+
)
679690
680691
julia> m[1](xs) |> size
681692
(96, 96, 7, 50)

src/layers/normalise.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ testmode!(m::BatchNorm, mode=true) =
278278

279279
function Base.show(io::IO, l::BatchNorm)
280280
print(io, "BatchNorm($(l.chs)")
281-
l.λ == identity || print(io, ", $(l.λ)")
281+
(l.λ == identity) || print(io, ", $(l.λ)")
282282
hasaffine(l) || print(io, ", affine=false")
283283
print(io, ")")
284284
end
@@ -443,8 +443,9 @@ testmode!(m::GroupNorm, mode = true) =
443443
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
444444

445445
function Base.show(io::IO, l::GroupNorm)
446+
# print(io, "GroupNorm($(join(size(l.β), ", "))", ", ", l.G)
446447
print(io, "GroupNorm($(l.chs), $(l.G)")
447-
l.λ == identity || print(io, ", $(l.λ)")
448+
l.λ == identity || print(io, ", ", l.λ)
448449
hasaffine(l) || print(io, ", affine=false")
449450
print(io, ")")
450451
end

0 commit comments

Comments (0)