From 4af876b06f09249a2601b8357485b64130d955ad Mon Sep 17 00:00:00 2001
From: abieler
Date: Thu, 28 Dec 2023 13:45:02 +0100
Subject: [PATCH] First draft GPS conv layer

---
 src/layers/conv.jl | 210 +++++++++++++++++++++++++++++----------------
 1 file changed, 137 insertions(+), 73 deletions(-)

diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index c1ea60b1e..06dd73506 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -74,33 +74,31 @@ end
 @functor GCNConv
 
 function GCNConv(ch::Pair{Int, Int}, σ = identity;
-                 init = glorot_uniform,
-                 bias::Bool = true,
-                 add_self_loops = true,
-                 use_edge_weight = false)
+        init = glorot_uniform,
+        bias::Bool = true,
+        add_self_loops = true,
+        use_edge_weight = false)
     in, out = ch
     W = init(out, in)
     b = bias ? Flux.create_bias(W, true, out) : false
     GCNConv(W, b, σ, add_self_loops, use_edge_weight)
 end
 
-check_gcnconv_input(g::GNNGraph{<:ADJMAT_T}, edge_weight::AbstractVector) =
+function check_gcnconv_input(g::GNNGraph{<:ADJMAT_T}, edge_weight::AbstractVector)
     throw(ArgumentError("Providing external edge_weight is not yet supported for adjacency matrix graphs"))
+end
 
 function check_gcnconv_input(g::GNNGraph, edge_weight::AbstractVector)
-    if length(edge_weight) !== g.num_edges
+    if length(edge_weight) !== g.num_edges
         throw(ArgumentError("Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))"))
     end
 end
 
 check_gcnconv_input(g::GNNGraph, edge_weight::Nothing) = nothing
-
-function (l::GCNConv)(g::GNNGraph,
-                      x::AbstractMatrix{T},
-                      edge_weight::EW = nothing
-                      ) where {T, EW <: Union{Nothing, AbstractVector}}
-
+function (l::GCNConv)(g::GNNGraph,
+        x::AbstractMatrix{T},
+        edge_weight::EW = nothing) where {T, EW <: Union{Nothing, AbstractVector}}
     check_gcnconv_input(g, edge_weight)
 
     if l.add_self_loops
@@ -139,7 +137,7 @@ function (l::GCNConv)(g::GNNGraph,
 end
 
 function (l::GCNConv)(g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix,
-                      edge_weight::AbstractVector)
+        edge_weight::AbstractVector)
     g = GNNGraph(edge_index(g)...; g.num_nodes)   # convert to COO
     return l(g, x, edge_weight)
 end
@@ -190,7 +188,7 @@ struct ChebConv{W <: AbstractArray{<:Number, 3}, B} <: GNNLayer
 end
 
 function ChebConv(ch::Pair{Int, Int}, k::Int;
-                  init = glorot_uniform, bias::Bool = true)
+        init = glorot_uniform, bias::Bool = true)
     in, out = ch
     W = init(out, in, k)
     b = bias ? Flux.create_bias(W, true, out) : false
@@ -223,6 +221,60 @@ function Base.show(io::IO, l::ChebConv)
     print(io, ")")
 end
 
+struct DotProductAttention{W <: AbstractArray} <: GNNLayer
+    channel::Pair{Int, Int}
+    Wq::W
+    Wk::W
+    Wv::W
+end
+
+function DotProductAttention(ch::Pair{Int, Int}, init = glorot_uniform)
+    chin, chout = ch
+    Wq = init(chin, chin)
+    Wk = init(chin, chin)
+    Wv = init(chout, chin)
+    DotProductAttention(ch, Wq, Wk, Wv)
+end
+
+function (l::DotProductAttention)(g::AbstractGNNGraph, x)
+    # Global attention: every node attends to every other node;
+    # the graph structure g is not used here.
+    q = l.Wq * x
+    k = l.Wk * x
+    v = l.Wv * x
+    x, _ = dot_product_attention(q, k, v)
+    return x
+end
+@functor DotProductAttention
+
+@doc raw"""
+    GPSConv(in => out, gconv, σ=identity; init=glorot_uniform)
+
+GPS (General, Powerful, Scalable) graph transformer layer from the paper
+["Recipe for a General, Powerful, Scalable Graph Transformer"](https://arxiv.org/abs/2205.12454).
+
+Combines a local message-passing convolution `gconv` with a global dot-product
+attention over all nodes. Each branch has a residual connection, and the sum of
+the two branches is passed through a feed-forward network.
+
+# Arguments
+
+- `in`: The dimension of input node features.
+- `out`: The dimension of output node features. The residual connections require `in == out`.
+- `gconv`: Graph convolution layer used for the local branch, e.g. `GCNConv(in => out)`.
+- `σ`: Activation function of the feed-forward network. Default `identity`.
+- `init`: Weights' initializer. Default `glorot_uniform`.
+"""
+struct GPSConv <: GNNLayer
+    convlayer::GNNChain
+    attnlayer::GNNChain
+    ffn::GNNChain
+end
+
+function GPSConv(ch::Pair{Int, Int}, gconv, σ = identity; init = glorot_uniform)
+    in, out = ch
+    gattn = DotProductAttention(ch)
+    convlayer = GNNChain(gconv, Dropout(0.5), LayerNorm(out))
+    attnlayer = GNNChain(gattn, Dropout(0.5), LayerNorm(out))
+    # The feed-forward block operates on the out-dimensional branch outputs.
+    ffn = GNNChain(Dense(out => 2 * out, σ), Dropout(0.5), Dense(2 * out => out), Dropout(0.5))
+    return GPSConv(convlayer, attnlayer, ffn)
+end
+
+function (l::GPSConv)(g::AbstractGNNGraph, x)
+    check_num_nodes(g, x)
+    # Convolution and attention branches, each with a residual connection;
+    # their sum is fed to the feed-forward network.
+    f = GNNChain(
+        Parallel(+,
+            Parallel(+, l.convlayer, identity),
+            Parallel(+, l.attnlayer, identity),
+        ),
+        l.ffn
+    )
+    f(g, x)
+end
+
 @doc raw"""
     GraphConv(in => out, σ=identity; aggr=+, bias=true, init=glorot_uniform)
 
@@ -255,7 +307,7 @@ end
 @functor GraphConv
 
 function GraphConv(ch::Pair{Int, Int}, σ = identity; aggr = +,
-                   init = glorot_uniform, bias::Bool = true)
+        init = glorot_uniform, bias::Bool = true)
     in, out = ch
     W1 = init(out, in)
     W2 = init(out, in)
@@ -329,13 +381,15 @@ struct GATConv{DX <: Dense, DE <: Union{Dense, Nothing}, T, A <: AbstractMatrix,
 end
 
 @functor GATConv
-Flux.trainable(l::GATConv) = (dense_x = l.dense_x, dense_e = l.dense_e, bias = l.bias, a = l.a)
+function Flux.trainable(l::GATConv)
+    (dense_x = l.dense_x, dense_e = l.dense_e, bias = l.bias, a = l.a)
+end
 
 GATConv(ch::Pair{Int, Int}, args...; kws...) = GATConv((ch[1], 0) => ch[2], args...; kws...)
 
 function GATConv(ch::Pair{NTuple{2, Int}, Int}, σ = identity;
-                 heads::Int = 1, concat::Bool = true, negative_slope = 0.2,
-                 init = glorot_uniform, bias::Bool = true, add_self_loops = true)
+        heads::Int = 1, concat::Bool = true, negative_slope = 0.2,
+        init = glorot_uniform, bias::Bool = true, add_self_loops = true)
     (in, ein), out = ch
     if add_self_loops
         @assert ein==0 "Using edge features and setting add_self_loops=true at the same time is not yet supported."
@@ -352,7 +406,7 @@ end
 (l::GATConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g)))
 
 function (l::GATConv)(g::GNNGraph, x::AbstractMatrix,
-                      e::Union{Nothing, AbstractMatrix} = nothing)
+        e::Union{Nothing, AbstractMatrix} = nothing)
     check_num_nodes(g, x)
     @assert !((e === nothing) && (l.dense_e !== nothing)) "Input edge features required for this layer"
     @assert !((e !== nothing) && (l.dense_e === nothing)) "Input edge features were not specified in the layer constructor"
@@ -457,20 +511,22 @@ struct GATv2Conv{T, A1, A2, A3, B, C <: AbstractMatrix, F} <: GNNLayer
 end
 
 @functor GATv2Conv
-Flux.trainable(l::GATv2Conv) = (dense_i = l.dense_i, dense_j = l.dense_j, dense_e = l.dense_e, bias = l.bias, a = l.a)
+function Flux.trainable(l::GATv2Conv)
+    (dense_i = l.dense_i, dense_j = l.dense_j, dense_e = l.dense_e, bias = l.bias, a = l.a)
+end
 
 function GATv2Conv(ch::Pair{Int, Int}, args...; kws...)
     GATv2Conv((ch[1], 0) => ch[2], args...; kws...)
 end
 
 function GATv2Conv(ch::Pair{NTuple{2, Int}, Int},
-                   σ = identity;
-                   heads::Int = 1,
-                   concat::Bool = true,
-                   negative_slope = 0.2,
-                   init = glorot_uniform,
-                   bias::Bool = true,
-                   add_self_loops = true)
+        σ = identity;
+        heads::Int = 1,
+        concat::Bool = true,
+        negative_slope = 0.2,
+        init = glorot_uniform,
+        bias::Bool = true,
+        add_self_loops = true)
     (in, ein), out = ch
 
     if add_self_loops
@@ -488,13 +544,13 @@ function GATv2Conv(ch::Pair{NTuple{2, Int}, Int},
     a = init(out, heads)
     negative_slope = convert(eltype(dense_i.weight), negative_slope)
     GATv2Conv(dense_i, dense_j, dense_e, b, a, σ, negative_slope, ch, heads, concat,
-              add_self_loops)
+        add_self_loops)
 end
 
 (l::GATv2Conv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g)))
 
 function (l::GATv2Conv)(g::GNNGraph, x::AbstractMatrix,
-                        e::Union{Nothing, AbstractMatrix} = nothing)
+        e::Union{Nothing, AbstractMatrix} = nothing)
     check_num_nodes(g, x)
     @assert !((e === nothing) && (l.dense_e !== nothing)) "Input edge features required for this layer"
     @assert !((e !== nothing) && (l.dense_e === nothing)) "Input edge features were not specified in the layer constructor"
@@ -575,7 +631,7 @@ end
 @functor GatedGraphConv
 
 function GatedGraphConv(out_ch::Int, num_layers::Int;
-                        aggr = +, init = glorot_uniform)
+        aggr = +, init = glorot_uniform)
     w = init(out_ch, out_ch, num_layers)
     gru = GRUCell(out_ch, out_ch)
     GatedGraphConv(w, gru, out_ch, num_layers, aggr)
@@ -724,7 +780,7 @@ end
 @functor NNConv
 
 function NNConv(ch::Pair{Int, Int}, nn, σ = identity; aggr = +, bias = true,
-                init = glorot_uniform)
+        init = glorot_uniform)
     in, out = ch
     W = init(out, in)
     b = bias ? Flux.create_bias(W, true, out) : false
@@ -786,7 +842,7 @@ end
 @functor SAGEConv
 
 function SAGEConv(ch::Pair{Int, Int}, σ = identity; aggr = mean,
-                  init = glorot_uniform, bias::Bool = true)
+        init = glorot_uniform, bias::Bool = true)
     in, out = ch
     W = init(out, 2 * in)
     b = bias ? Flux.create_bias(W, true, out) : false
@@ -844,7 +900,7 @@ end
 @functor ResGatedGraphConv
 
 function ResGatedGraphConv(ch::Pair{Int, Int}, σ = identity;
-                           init = glorot_uniform, bias::Bool = true)
+        init = glorot_uniform, bias::Bool = true)
     in, out = ch
     A = init(out, in)
     B = init(out, in)
@@ -932,7 +988,7 @@ end
 
 CGConv(ch::Pair{Int, Int}, args...; kws...) = CGConv((ch[1], 0) => ch[2], args...; kws...)
 function CGConv(ch::Pair{NTuple{2, Int}, Int}, act = identity; residual = false,
-                bias = true, init = glorot_uniform)
+        bias = true, init = glorot_uniform)
     (nin, ein), out = ch
     dense_f = Dense(2nin + ein, out, sigmoid; bias, init)
     dense_s = Dense(2nin + ein, out, act; bias, init)
@@ -940,7 +996,7 @@
 function (l::CGConv)(g::GNNGraph, x::AbstractMatrix,
-                     e::Union{Nothing, AbstractMatrix} = nothing)
+        e::Union{Nothing, AbstractMatrix} = nothing)
     check_num_nodes(g, x)
     if e !== nothing
         check_num_edges(g, e)
@@ -1030,7 +1086,7 @@ function (l::AGNNConv)(g::GNNGraph, x::AbstractMatrix)
     α = softmax_edge_neighbors(g, l.β .* cos_dist)
 
     x = propagate(g, +; xj = x, e = α) do xi, xj, α
-            α .* xj
+        α .* xj
     end
 
     return x
@@ -1080,10 +1136,10 @@ MEGNetConv(ϕe, ϕv; aggr = mean) = MEGNetConv(ϕe, ϕv, aggr)
 function MEGNetConv(ch::Pair{Int, Int}; aggr = mean)
     nin, nout = ch
     ϕe = Chain(Dense(3nin, nout, relu),
-               Dense(nout, nout))
+        Dense(nout, nout))
 
     ϕv = Chain(Dense(nin + nout, nout, relu),
-               Dense(nout, nout))
+        Dense(nout, nout))
 
     MEGNetConv(ϕe, ϕv; aggr)
 end
@@ -1169,11 +1225,11 @@ end
 @functor GMMConv
 
 function GMMConv(ch::Pair{NTuple{2, Int}, Int},
-                 σ = identity;
-                 K::Int = 1,
-                 bias::Bool = true,
-                 init = Flux.glorot_uniform,
-                 residual = false)
+        σ = identity;
+        K::Int = 1,
+        bias::Bool = true,
+        init = Flux.glorot_uniform,
+        residual = false)
     (nin, ein), out = ch
     mu = init(ein, K)
     sigma_inv = init(ein, K)
@@ -1281,10 +1337,10 @@ end
 @functor SGConv
 
 function SGConv(ch::Pair{Int, Int}, k = 1;
-                init = glorot_uniform,
-                bias::Bool = true,
-                add_self_loops = true,
-                use_edge_weight = false)
+        init = glorot_uniform,
+        bias::Bool = true,
+        add_self_loops = true,
+        use_edge_weight = false)
     in, out = ch
     W = init(out, in)
     b = bias ? Flux.create_bias(W, true, out) : false
@@ -1292,8 +1348,8 @@ function SGConv(ch::Pair{Int, Int}, k = 1;
 end
 
 function (l::SGConv)(g::GNNGraph, x::AbstractMatrix{T},
-                     edge_weight::EW = nothing) where
-    {T, EW <: Union{Nothing, AbstractVector}}
+        edge_weight::EW = nothing) where
+        {T, EW <: Union{Nothing, AbstractVector}}
     @assert !(g isa GNNGraph{<:ADJMAT_T} && edge_weight !== nothing) "Providing external edge_weight is not yet supported for adjacency matrix graphs"
 
     if edge_weight !== nothing
@@ -1314,7 +1370,7 @@ function (l::SGConv)(g::GNNGraph, x::AbstractMatrix{T},
     if edge_weight !== nothing
         d = degree(g, T; dir = :in, edge_weight)
     else
-        d = degree(g, T; dir = :in, edge_weight=l.use_edge_weight)
+        d = degree(g, T; dir = :in, edge_weight = l.use_edge_weight)
     end
     c = 1 ./ sqrt.(d)
     for iter in 1:(l.k)
@@ -1335,7 +1391,7 @@ function (l::SGConv)(g::GNNGraph, x::AbstractMatrix{T},
 end
 
 function (l::SGConv)(g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix,
-                     edge_weight::AbstractVector)
+        edge_weight::AbstractVector)
     g = GNNGraph(edge_index(g)...; g.num_nodes)
     return l(g, x, edge_weight)
 end
@@ -1417,22 +1473,22 @@ end
 
 #Follows reference implementation at https://github.com/vgsatorras/egnn/blob/main/models/egnn_clean/egnn_clean.py
 function EGNNConv(ch::Pair{NTuple{2, Int}, Int}; hidden_size::Int = 2 * ch[1][1],
-                  residual = false)
+        residual = false)
     (in_size, edge_feat_size), out_size = ch
     act_fn = swish
 
     # +1 for the radial feature: ||x_i - x_j||^2
     ϕe = Chain(Dense(in_size * 2 + edge_feat_size + 1 => hidden_size, act_fn),
-               Dense(hidden_size => hidden_size, act_fn))
+        Dense(hidden_size => hidden_size, act_fn))
 
     ϕh = Chain(Dense(in_size + hidden_size, hidden_size, swish),
-               Dense(hidden_size, out_size))
+        Dense(hidden_size, out_size))
 
     ϕx = Chain(Dense(hidden_size, hidden_size, swish),
-               Dense(hidden_size, 1, bias = false))
+        Dense(hidden_size, 1, bias = false))
 
     num_features = (in = in_size, edge = edge_feat_size, out = out_size,
-                    hidden = hidden_size)
+        hidden = hidden_size)
     if residual
         @assert in_size==out_size "Residual connection only possible if in_size == out_size"
     end
@@ -1450,7 +1506,7 @@ function (l::EGNNConv)(g::GNNGraph, h::AbstractMatrix, x::AbstractMatrix, e = no
     x_diff = x_diff ./ (sqrt.(sqnorm_xdiff) .+ 1.0f-6)
 
     msg = apply_edges(message, g, l,
-                      xi = (; h), xj = (; h), e = (; e, x_diff, sqnorm_xdiff))
+        xi = (; h), xj = (; h), e = (; e, x_diff, sqnorm_xdiff))
     h_aggr = aggregate_neighbors(g, +, msg.h)
     x_aggr = aggregate_neighbors(g, mean, msg.x)
@@ -1569,7 +1625,15 @@ end
 @functor TransformerConv
 
 function Flux.trainable(l::TransformerConv)
-    (W1 = l.W1, W2 = l.W2, W3 = l.W3, W4 = l.W4, W5 = l.W5, W6 = l.W6, FF = l.FF, BN1 = l.BN1, BN2 = l.BN2)
+    (W1 = l.W1,
+     W2 = l.W2,
+     W3 = l.W3,
+     W4 = l.W4,
+     W5 = l.W5,
+     W6 = l.W6,
+     FF = l.FF,
+     BN1 = l.BN1,
+     BN2 = l.BN2)
 end
 
 function TransformerConv(ch::Pair{Int, Int}, args...; kws...)
@@ -1577,17 +1641,17 @@ function TransformerConv(ch::Pair{Int, Int}, args...; kws...)
 end
 
 function TransformerConv(ch::Pair{NTuple{2, Int}, Int};
-                         heads::Int = 1,
-                         concat::Bool = true,
-                         init = glorot_uniform,
-                         add_self_loops::Bool = false,
-                         bias_qkv = true,
-                         bias_root::Bool = true,
-                         root_weight::Bool = true,
-                         gating::Bool = false,
-                         skip_connection::Bool = false,
-                         batch_norm::Bool = false,
-                         ff_channels::Int = 0)
+        heads::Int = 1,
+        concat::Bool = true,
+        init = glorot_uniform,
+        add_self_loops::Bool = false,
+        bias_qkv = true,
+        bias_root::Bool = true,
+        root_weight::Bool = true,
+        gating::Bool = false,
+        skip_connection::Bool = false,
+        batch_norm::Bool = false,
+        ff_channels::Int = 0)
     (in, ein), out = ch
 
     if add_self_loops
@@ -1604,17 +1668,17 @@ function TransformerConv(ch::Pair{NTuple{2, Int}, Int};
     W6 = ein > 0 ? Dense(ein => out * heads; bias = bias_qkv, init = init) : nothing
     FF = ff_channels > 0 ? Chain(Dense(out_mha => ff_channels, relu),
-                                 Dense(ff_channels => out_mha)) : nothing
+        Dense(ff_channels => out_mha)) : nothing
     BN1 = batch_norm ? BatchNorm(out_mha) : nothing
     BN2 = (batch_norm && ff_channels > 0) ? BatchNorm(out_mha) : nothing
 
     return TransformerConv(W1, W2, W3, W4, W5, W6, FF, BN1, BN2,
-                           ch, heads, add_self_loops, concat, skip_connection,
-                           Float32(√out))
+        ch, heads, add_self_loops, concat, skip_connection,
+        Float32(√out))
 end
 
 function (l::TransformerConv)(g::GNNGraph, x::AbstractMatrix,
-                              e::Union{AbstractMatrix, Nothing} = nothing)
+        e::Union{AbstractMatrix, Nothing} = nothing)
     check_num_nodes(g, x)
 
     if l.add_self_loops
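
A minimal usage sketch for the new GPSConv layer (illustration only, not part of the patch). It assumes GraphNeuralNetworks.jl and Flux are loaded and uses equal input and output feature sizes, since the residual branches add the input features to each branch's output:

    using Flux, GraphNeuralNetworks

    din = 16                       # in == out is required by the residual connections
    g = rand_graph(10, 30)         # random graph with 10 nodes and 30 edges
    x = randn(Float32, din, 10)    # node features, one column per node

    layer = GPSConv(din => din, GCNConv(din => din), relu)
    y = layer(g, x)                # updated node features of size (din, 10)

Note that DotProductAttention attends over all nodes of the input feature matrix at once, so when several graphs are batched into one GNNGraph, node features interact across graphs.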