diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl
index 86bab1896..011763160 100644
--- a/src/GraphNeuralNetworks.jl
+++ b/src/GraphNeuralNetworks.jl
@@ -64,6 +64,7 @@ export
     SGConv,
 
     # layers/pool
+    ConcatPool,
     GlobalPool,
     GlobalAttentionPool,
     TopKPool,
diff --git a/src/layers/pool.jl b/src/layers/pool.jl
index 88ed3ddd0..c9c6e36c8 100644
--- a/src/layers/pool.jl
+++ b/src/layers/pool.jl
@@ -1,5 +1,47 @@
 using DataStructures: nlargest
 
+@doc raw"""
+    ConcatPool(pooling_layer)
+
+Concatenate to each node's feature vector the graph-level feature
+obtained by applying `pooling_layer` to the whole graph:
+
+```math
+\mathbf{x}_i' = [\mathbf{x}_i; \mathbf{u}_V]
+```
+
+# Arguments
+
+- `pooling_layer`: a global pooling layer mapping the node features of a
+  graph to a single graph-level feature vector, e.g. [`GlobalPool`](@ref).
+
+# Examples
+
+```julia
+using Flux, Statistics, GraphNeuralNetworks, Graphs
+
+add_pool = ConcatPool(GlobalPool(mean))
+
+g = GNNGraph(rand_graph(10, 4))
+X = rand(32, 10)
+add_pool(g, X) # => 64x10 matrix
+```
+"""
+struct ConcatPool{P<:GNNLayer} <: GNNLayer
+    pool::P
+end
+
+function (l::ConcatPool)(g::GNNGraph, x::AbstractArray)
+    # Graph-level feature from the wrapped pooling layer ...
+    g_feat = applylayer(l.pool, g, x)
+    # ... replicated to one column per node (batched graphs handled by
+    # broadcast_nodes, which repeats each graph's feature for its own nodes).
+    feat_arr = broadcast_nodes(g, g_feat)
+    return vcat(x, feat_arr)
+end
+
+# The output is a per-node feature matrix, so it belongs in ndata (not gdata).
+(l::ConcatPool)(g::GNNGraph) = GNNGraph(g, ndata=l(g, node_features(g)))
+
 @doc raw"""
     GlobalPool(aggr)
 
@@ -44,7 +86,6 @@ end
 
 (l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata=l(g, node_features(g)))
 
-
 @doc raw"""
     GlobalAttentionPool(fgate, ffeat=identity)
 
diff --git a/test/layers/pool.jl b/test/layers/pool.jl
index f7bb74a83..6d84b01d4 100644
--- a/test/layers/pool.jl
+++ b/test/layers/pool.jl
@@ -21,6 +21,35 @@
         test_layer(p, g, rtol=1e-5, exclude_grad_fields = [:aggr], outtype=:graph)
     end
 
+    @testset "ConcatPool" begin
+        p = GlobalPool(+)
+        l = ConcatPool(p)
+        n = 10
+        chin = 6
+        X = rand(Float32, chin, n)
+        g = GNNGraph(random_regular_graph(n, 4), ndata=X, graph_type=GRAPH_T)
+        y = p(g, X)
+        u = l(g, X)
+
+        @test size(u) == (chin*2, n)
+        @test u[1:chin,:] ≈ X
+        @test u[chin+1:end,:] ≈ repeat(y, 1, n)
+
+        n = [5, 6, 7]
+        ng = length(n)
+        g = Flux.batch([GNNGraph(random_regular_graph(n[i], 4),
+                                 ndata=rand(Float32, chin, n[i]),
+                                 graph_type=GRAPH_T)
+                        for i=1:ng])
+        y = p(g, g.ndata.x)
+        u = l(g, g.ndata.x)
+        @test size(u) == (chin*2, sum(n))
+        @test u[1:chin,:] ≈ g.ndata.x
+        @test u[chin+1:end,:] ≈ hcat([repeat(y[:,i], 1, n[i]) for i=1:ng]...)
+
+        test_layer(l, g, rtol=1e-5, exclude_grad_fields = [:aggr], outtype=:node)
+    end
+
     @testset "GlobalAttentionPool" begin
         n = 10
         chin = 6