Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit f1d680a

Browse files
authored
Merge pull request #75 from yuehhua/gno
Fix GraphKernel
2 parents 88d6f7c + cf34d62 commit f1d680a

File tree

2 files changed

+19
-11
lines changed

2 files changed

+19
-11
lines changed

src/graph_kernel.jl

+9-4
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,22 @@ end
2828

2929
Flux.@functor GraphKernel
3030

31-
function GeometricFlux.message(l::GraphKernel, x_i::AbstractArray, x_j::AbstractArray, e_ij)
32-
return l.κ(vcat(x_i, x_j))
31+
function GeometricFlux.message(l::GraphKernel, x_i, x_j::AbstractArray, e_ij::AbstractArray)
32+
N = size(x_j, 1)
33+
K = l.κ(e_ij)
34+
dims = size(K)[2:end]
35+
m_ij = GeometricFlux._matmul(reshape(K, N, N, :), reshape(x_j, N, 1, :))
36+
return reshape(m_ij, N, dims...)
3337
end
3438

3539
function GeometricFlux.update(l::GraphKernel, m::AbstractArray, x::AbstractArray)
3640
return l.σ.(GeometricFlux._matmul(l.linear, x) + m)
3741
end
3842

39-
function (l::GraphKernel)(el::NamedTuple, X::AbstractArray)
43+
function (l::GraphKernel)(el::NamedTuple, X::AbstractArray, E::AbstractArray)
4044
GraphSignals.check_num_nodes(el.N, X)
41-
_, V, _ = GeometricFlux.propagate(l, el, nothing, X, nothing, mean, nothing, nothing)
45+
GraphSignals.check_num_nodes(el.E, E)
46+
_, V, _ = GeometricFlux.propagate(l, el, E, X, nothing, mean, nothing, nothing)
4247
return V
4348
end
4449

test/graph_kernel.jl

+10-7
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
@testset "GraphKernel" begin
22
batch_size = 5
33
channel = 32
4-
N = 10 * 10
4+
coord_dim = 2
5+
N = 10
56

6-
κ = Dense(2 * channel, channel, relu)
7+
graph = grid([N, N])
8+
κ = Dense(2(coord_dim + 1), abs2(channel), relu)
79

8-
graph = grid([10, 10])
9-
𝐱 = rand(Float32, channel, N, batch_size)
10+
𝐱 = rand(Float32, channel, nv(graph), batch_size)
11+
E = rand(Float32, 2(coord_dim + 1), ne(graph), batch_size)
1012
l = WithGraph(FeaturedGraph(graph), GraphKernel(κ, channel))
11-
@test repr(l.layer) == "GraphKernel(Dense(64 => 32, relu), channel=32)"
12-
@test size(l(𝐱)) == (channel, N, batch_size)
13+
@test repr(l.layer) ==
14+
"GraphKernel(Dense($(2(coord_dim + 1)) => $(abs2(channel)), relu), channel=32)"
15+
@test size(l(𝐱, E)) == (channel, nv(graph), batch_size)
1316

14-
g = Zygote.gradient(() -> sum(l(𝐱)), Flux.params(l))
17+
g = Zygote.gradient(() -> sum(l(𝐱, E)), Flux.params(l))
1518
@test length(g.grads) == 3
1619
end

0 commit comments

Comments
 (0)