module NilNumber

using NNlib

"""
    Nil <: Number

Nil is a singleton type with a single instance `nil`.
Unlike `Nothing` and `Missing`, it subtypes `Number`.
"""
struct Nil <: Number end

const nil = Nil()

# converting between `nil` and any other `Number` always yields `nil`
Nil(::T) where T<:Number = nil
(::Type{T})(::Nil) where T<:Number = nil
Base.convert(::Type{Nil}, ::Number) = nil

Base.float(::Type{Nil}) = Nil

for f in [:copy, :zero, :one, :oneunit,
          :+, :-, :abs, :abs2, :inv,
          :exp, :log, :log1p, :log2, :log10,
          :sqrt, :tanh, :conj]
  @eval Base.$f(::Nil) = nil
end

for f in [:+, :-, :*, :/, :^, :mod, :div, :rem]
  @eval Base.$f(::Nil, ::Nil) = nil
end

# the ordering is arbitrary, but defining `isless` lets comparison-based
# functions (`min`, `max`, and hence activations such as `relu`) accept `nil`
Base.isless(::Nil, ::Nil) = true
Base.isless(::Nil, ::Number) = true
Base.isless(::Number, ::Nil) = true

Base.isnan(::Nil) = false

Base.typemin(::Type{Nil}) = nil
Base.typemax(::Type{Nil}) = nil

Base.promote_rule(x::Type{Nil}, y::Type{<:Number}) = Nil
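
# Illustration of the absorbing behaviour the definitions above give: every
# operation touching `nil` collapses to `nil`, so a forward pass does no real
# arithmetic, only shape bookkeeping. For example:
#   nil + 1    # nil -- `promote_rule` converts the 1 to `nil` first
#   2.0 * nil  # nil
#   sqrt(nil)  # nil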

end # module

using .NilNumber: Nil, nil

| 47 | +""" |
| 48 | + outputsize(m, inputsize::Tuple; padbatch=false) |
| 49 | +
|
| 50 | +Calculate the output size of model `m` given the input size. |
| 51 | +Obeys `outputsize(m, size(x)) == size(m(x))` for valid input `x`. |
| 52 | +Keyword `padbatch=true` is equivalent to using `(inputsize..., 1)`, and |
| 53 | +returns the final size including this extra batch dimension. |
| 54 | +
|
| 55 | +This should be faster than calling `size(m(x))`. It uses a trivial number type, |
| 56 | +and thus should work out of the box for custom layers. |
| 57 | +
|
| 58 | +If `m` is a `Tuple` or `Vector`, its elements are applied in sequence, like `Chain(m...)`. |
| 59 | +
|
| 60 | +# Examples |
```jldoctest
julia> using Flux: outputsize

julia> outputsize(Dense(10, 4), (10,); padbatch=true)
(4, 1)

julia> m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32));

julia> m(randn(Float32, 10, 10, 3, 64)) |> size
(6, 6, 32, 64)

julia> outputsize(m, (10, 10, 3); padbatch=true)
(6, 6, 32, 1)

julia> outputsize(m, (10, 10, 3, 64))
(6, 6, 32, 64)

julia> try outputsize(m, (10, 10, 7, 64)) catch e println(e) end
DimensionMismatch("Input channels must match! (7 vs. 3)")

julia> outputsize([Dense(10, 4), Dense(4, 2)], (10, 1))
(2, 1)

julia> using LinearAlgebra: norm

julia> f(x) = x ./ norm.(eachcol(x));

julia> outputsize(f, (10, 1)) # manually specify batch size as 1
(10, 1)

julia> outputsize(f, (10,); padbatch=true) # no need to mention batch size
(10, 1)
```
| 94 | +""" |
| 95 | +function outputsize(m, inputsize::Tuple; padbatch=false) |
| 96 | + inputsize = padbatch ? (inputsize..., 1) : inputsize |
| 97 | + |
| 98 | + return size(m(fill(nil, inputsize))) |
| 99 | +end |
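
# A minimal sketch of why custom layers work out of the box: any callable
# built from standard array operations propagates `Nil` through broadcasting
# and promotion. (`Scale` is a hypothetical layer, for illustration only.)
#
#   struct Scale{T}; s::T; end
#   (l::Scale)(x) = l.s .* x
#
#   outputsize(Scale(2.0), (5, 3))  # (5, 3), with no real multiplications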

## make tuples and vectors be like Chains

outputsize(m::Tuple, inputsize::Tuple; padbatch=false) = outputsize(Chain(m...), inputsize; padbatch=padbatch)
outputsize(m::AbstractVector, inputsize::Tuple; padbatch=false) = outputsize(Chain(m...), inputsize; padbatch=padbatch)
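
# e.g. outputsize((Dense(10, 4), Dense(4, 2)), (10, 1)) == (2, 1),
# matching the `Vector` example in the docstring above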

## bypass statistics in normalization layers

# normalization layers never change shapes, so on `Nil` arrays they can
# simply act as the identity instead of computing any statistics
for layer in (:LayerNorm, :BatchNorm, :InstanceNorm, :GroupNorm)
  @eval (l::$layer)(x::AbstractArray{Nil}) = x
end
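
# e.g. outputsize(BatchNorm(3), (4, 4, 3, 10)) == (4, 4, 3, 10)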

## fixes for layers that don't work out of the box

for (fn, Dims) in ((:conv, DenseConvDims), (:depthwiseconv, DepthwiseConvDims))
  @eval begin
    # with two Nil arrays, skip the convolution entirely and build an
    # output array of the size NNlib would produce
    function NNlib.$fn(a::AbstractArray{Nil}, b::AbstractArray{Nil}, dims::$Dims)
      fill(nil, NNlib.output_size(dims)..., NNlib.channels_out(dims), size(a)[end])
    end

    # mixed Nil/Real cases reduce to the all-Nil case above
    function NNlib.$fn(a::AbstractArray{<:Real}, b::AbstractArray{Nil}, dims::$Dims)
      NNlib.$fn(fill(nil, size(a)), b, dims)
    end

    function NNlib.$fn(a::AbstractArray{Nil}, b::AbstractArray{<:Real}, dims::$Dims)
      NNlib.$fn(a, fill(nil, size(b)), dims)
    end
  end
end
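
# Quick shape check for the rule above (a sketch, assuming NNlib's default
# DenseConvDims: no padding, stride 1):
#   x = fill(nil, 10, 10, 3, 2); w = fill(nil, 3, 3, 3, 16);
#   size(NNlib.conv(x, w, DenseConvDims(x, w)))  # (8, 8, 16, 2)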