@@ -28,6 +28,20 @@ julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x)
2828true
2929```
3030
31+ A chain may be called with multiple arguments, which is equivalent to calling it
32+ with one tuple of these arguments. Such a tuple is understood by [`Parallel`](@ref)
33+ to mean the same as several arguments:
34+
35+ ```jldoctest
36+ julia> Chain(println, println)(1, 2, 3) # three arguments become a tuple
37+ (1, 2, 3)
38+ nothing
39+
40+ julia> Chain(x->@show(x), Parallel(+, inv, abs2))(4, 5) # returns 1/4 + 5^2
41+ x = (4, 5)
42+ 25.25
43+ ```
44+
3145For large models, there is a special type-unstable path which can reduce compilation
3246times. This can be used by supplying a vector of layers `Chain([layer1, layer2, ...])`.
3347This feature is somewhat experimental, beware!
# Forward the standard collection interface to the wrapped tuple of layers,
# so a Chain can be indexed, iterated and measured just like `c.layers`.
@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
  Base.iterate, Base.lastindex, Base.keys, Base.firstindex
4862
@layer :expand Chain  # the option :expand opts-in to container-style pretty-printing
5064
# Calling a Chain threads the input through each layer in order.
(c::Chain)(x) = _applychain(c.layers, x)
# Several positional arguments are bundled into one tuple before
# the first layer sees them.
(c::Chain)(x, ys...) = _applychain(c.layers, (x, ys...))
5267
5368@generated function _applychain (layers:: Tuple{Vararg{Any,N}} , x) where {N}
5469 symbols = vcat (:x , [gensym () for _ in 1 : N])
# Slicing a Chain with an array of indices yields a new Chain of those layers.
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i])
# For a named Chain, also keep the corresponding subset of names.
Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
  Chain(NamedTuple{keys(c)[i]}(Tuple(c.layers)[i]))
86+
7187function Base. show (io:: IO , c:: Chain )
7288 print (io, " Chain(" )
7389 _show_layers (io, c. layers)
475491Create a layer which passes an input array to each path in
476492`layers`, before reducing the output with `connection`.
477493
478- Called with one input `x`, this is equivalent to `connection([l(x) for l in layers]...)`.
479- If called with multiple inputs, one is passed to each layer, thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.
494+ Obeys similar rules to broadcasting:
495+ * Called with one input `x`, this is equivalent to `connection([l(x) for l in layers]...)`.
496+ * With multiple `inputs` and just one layer, it is instead `connection([layer(x) for x in inputs]...)`.
497+ * With multiple inputs and multiple layers, one input is passed to each layer,
498+ thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.
480499
481500Like [`Chain`](@ref), its sub-layers may be given names using the keyword constructor.
482501These can be accessed by indexing: `m[1] == m[:name]` is the first layer.
@@ -486,6 +505,25 @@ and [`Maxout`](@ref) which reduces by broadcasting `max`.
486505
487506# Examples
488507
508+ ```jldoctest
509+ julia> p = Parallel(+, abs2, sqrt);
510+
511+ julia> p(3, 4) # == 3^2 + √4, two functions two inputs
512+ 11.0
513+
514+ julia> p((3, 4)) # tuple is always splatted
515+ 11.0
516+
517+ julia> p(4) # == 4^2 + √4, one input used twice
518+ 18.0
519+
520+ julia> Parallel(hcat, inv)(1, 2, 4) # one function three inputs
521+ 1×3 Matrix{Float64}:
522+ 1.0 0.5 0.25
523+ ```
524+
525+ With Flux layers:
526+
489527```jldoctest
490528julia> model = Chain(Dense(3 => 5),
491529 Parallel(vcat, Dense(5 => 4), Chain(Dense(5 => 7), Dense(7 => 4))),
@@ -516,35 +554,47 @@ struct Parallel{F, T<:Union{Tuple, NamedTuple}}
516554 layers:: T
517555end
518556
# Alias for a Parallel layer holding exactly one sub-layer (a 1-tuple or a
# 1-element NamedTuple); used to dispatch the "one layer, many inputs" methods.
# Declared `const` so the alias is a constant binding rather than an untyped
# non-constant global, which Julia cannot infer through.
const _ParallelONE{T} = Parallel{T, <:Union{Tuple{Any}, NamedTuple{<:Any, <:Tuple{Any}}}}
558+
# Positional constructor: gather the given layers into a tuple.
Parallel(connection, layers...) = Parallel(connection, layers)

# Keyword constructor: gather named layers into a NamedTuple, rejecting the
# two names that would clash with the struct's own fields.
function Parallel(connection; kw...)
  nt = NamedTuple(kw)
  if haskey(nt, :layers) || haskey(nt, :connection)
    throw(ArgumentError("a Parallel layer cannot have a named sub-layer called `connection` or `layers`"))
  end
  return Parallel(connection, nt)
end

# Constructing a Parallel with zero sub-layers is an error.
Parallel(connection, layers::Union{Tuple{}, @NamedTuple{}}) =
  throw(ArgumentError("cannot construct a Parallel layer with no sub-layers"))
528569
@layer :expand Parallel  # :expand opts-in to container-style pretty-printing
530571
# One argument: give the same input to every sub-layer, then reduce the
# results with `connection`.
(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...)  # one argument
533573
# Internal validation: a Parallel with several sub-layers must receive either
# one input or exactly one input per layer; anything else is an ArgumentError.
# The single-layer and single-input cases never reach this function — they are
# handled by dispatch.
function _parallel_check(layers, xs)
  nlayers = length(layers)
  @assert nlayers > 1  # dispatch handles nl==1 cases
  ninputs = length(xs)
  if nlayers != ninputs
    throw(ArgumentError(lazy"Parallel with $nlayers > 1 sub-layers can take one input or $nlayers inputs, but got $ninputs inputs"))
  end
  return nothing
end
# Pure validation — keep AD from trying to differentiate through the check.
ChainRulesCore.@non_differentiable _parallel_check(nl, nx)
542583
# Several arguments with several layers: pair them up one-to-one.
function (m::Parallel)(x, ys...)
  args = (x, ys...)
  _parallel_check(m.layers, args)
  return m.connection(map(|>, args, Tuple(m.layers))...)  # multiple arguments & multiple layers
end

# Several arguments with a single layer: apply that layer to each argument.
(m::_ParallelONE)(x, ys...) =
  m.connection(map(z -> only(m.layers)(z), (x, ys...))...)  # multiple arguments, one layer

# A tuple argument is always splatted into separate arguments.
(m::Parallel)(xs::Tuple) = m(xs...)
(m::_ParallelONE)(xs::Tuple) = m(xs...)  # solves an ambiguity

# Zero arguments is never meaningful.
(m::Parallel)() = throw(ArgumentError("Parallel layer cannot take 0 inputs"))
597+
# A single index returns that sub-layer; a vector of indices returns a
# smaller Parallel sharing the same connection.
Base.getindex(m::Parallel, i) = m.layers[i]
Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i])
550600Base. getindex (m:: Parallel{<:Any, <:NamedTuple} , i:: AbstractVector ) =
0 commit comments