diff --git a/.gitignore b/.gitignore index 3d1804049..fe8ad6f3a 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ Manifest.toml LocalPreferences.toml .DS_Store /test.jl +try.jl \ No newline at end of file diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 19c200e7a..eb80de453 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1137,6 +1137,7 @@ end function (l::CGConv)(g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} = nothing) + print("CG G TYPES", length(g.etypes), g.etypes) check_num_nodes(g, x) xj, xi = expand_srcdst(g, x) @@ -1233,18 +1234,23 @@ function AGNNConv(; init_beta = 1.0f0, add_self_loops = true, trainable = true) AGNNConv([init_beta], add_self_loops, trainable) end -function (l::AGNNConv)(g::GNNGraph, x::AbstractMatrix) +function (l::AGNNConv)(g::AbstractGNNGraph, x) check_num_nodes(g, x) + xi, xj = expand_srcdst(g, x) + edge_t = g isa GNNHeteroGraph ? g.etypes[1] : nothing + if l.add_self_loops - g = add_self_loops(g) + g = g isa GNNHeteroGraph ? add_self_loops(g, edge_t) : add_self_loops(g) end - xn = x ./ sqrt.(sum(x .^ 2, dims = 1)) - cos_dist = apply_edges(xi_dot_xj, g, xi = xn, xj = xn) + xi_n = xi ./ sqrt.(sum(xi .^ 2, dims = 2)) + xj_n = xj ./ sqrt.(sum(xj .^ 2, dims = 2)) + + cos_dist = apply_edges(xi_dot_xj, g, xi = xi_n, xj = xj_n) α = softmax_edge_neighbors(g, l.β .* cos_dist) - x = propagate(g, +; xj = x, e = α) do xi, xj, α - α .* xj + x = propagate(g, +; xi = xi_n, xj = xj_n, e = α) do xi, xj, α + α .* xj .* xi end return x diff --git a/test/layers/heteroconv.jl b/test/layers/heteroconv.jl index 3d5f2c09c..73fffc810 100644 --- a/test/layers/heteroconv.jl +++ b/test/layers/heteroconv.jl @@ -125,6 +125,16 @@ @test size(y.A) == (2, 2) && size(y.B) == (2, 3) end + @testset "AGNNConv" begin + x = (A = rand(Float32, 2, 2), B = rand(Float32, 3, 3)) + + layers = HeteroGraphConv((:A, :to, :B) => AGNNConv(init_beta=1.0, add_self_loops=true, trainable=true), + (:B, :to, :A) => AGNNConv(init_beta=1.0, add_self_loops=true, trainable=true)); + + y = layers(hg, x); + @test size(y.A) == (2, 2) && size(y.B) == (2, 3) + end + @testset "GINConv" begin x = (A = rand(4, 2), B = rand(4, 3)) layers = HeteroGraphConv((:A, :to, :B) => GINConv(Dense(4, 2), 0.4),