diff --git a/GNNGraphs/test/Project.toml b/GNNGraphs/test/Project.toml
index f18c35628..17a8dda21 100644
--- a/GNNGraphs/test/Project.toml
+++ b/GNNGraphs/test/Project.toml
@@ -2,7 +2,6 @@
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
-GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
 Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
@@ -21,5 +20,3 @@ TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
 TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
-[compat]
-GPUArraysCore = "0.1"
diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl
index 43cddbe8a..6c0676a3a 100644
--- a/GNNLux/src/GNNLux.jl
+++ b/GNNLux/src/GNNLux.jl
@@ -41,12 +41,12 @@ export AGNNConv,
        # TransformerConv
 
 include("layers/temporalconv.jl")
-export TGCN,
-       A3TGCN,
-       GConvGRU,
-       GConvLSTM,
-       DCGRU,
-       EvolveGCNO
+export GNNRecurrence,
+       GConvGRU, GConvGRUCell,
+       GConvLSTM, GConvLSTMCell,
+       DCGRU, DCGRUCell,
+       EvolveGCNO, EvolveGCNOCell,
+       TGCN, TGCNCell
 
 include("layers/pool.jl")
 export GlobalPool,
diff --git a/GNNLux/src/layers/basic.jl b/GNNLux/src/layers/basic.jl
index ea0df2908..002e6a122 100644
--- a/GNNLux/src/layers/basic.jl
+++ b/GNNLux/src/layers/basic.jl
@@ -10,6 +10,8 @@ abstract type GNNLayer <: AbstractLuxLayer end
 
 abstract type GNNContainerLayer{T} <: AbstractLuxContainerLayer{T} end
 
+const AbstractGNNLayer = Union{GNNLayer, GNNContainerLayer}
+
 """
     GNNChain(layers...)
     GNNChain(name = layer, ...)
@@ -104,3 +106,22 @@ _applylayer(l, g::GNNGraph, x, ps, st) = l(x), (;)
 _applylayer(l::AbstractLuxLayer, g::GNNGraph, x, ps, st) = l(x, ps, st)
 _applylayer(l::GNNLayer, g::GNNGraph, x, ps, st) = l(g, x, ps, st)
 _applylayer(l::GNNContainerLayer, g::GNNGraph, x, ps, st) = l(g, x, ps, st)
+
+
+# Facilitate using GNNlib functions with Lux layers
+# by returning a StatefulLuxLayer when accessing properties
+function Base.getproperty(l::StatefulLuxLayer{ST,<:AbstractGNNLayer}, name::Symbol) where ST
+    hasfield(typeof(l), name) && return getfield(l, name)
+    f = getproperty(l.model, name)
+    if f isa AbstractLuxLayer
+        stf = getproperty(Lux.get_state(l), name)
+        psf = getproperty(l.ps, name)
+        if ST === Static.True
+            return StatefulLuxLayer{true}(f, psf, stf)
+        else
+            return StatefulLuxLayer{false}(f, psf, stf)
+        end
+    else
+        return f
+    end
+end
diff --git a/GNNLux/src/layers/temporalconv.jl b/GNNLux/src/layers/temporalconv.jl
index d6036cd8f..8b5b0dc9b 100644
--- a/GNNLux/src/layers/temporalconv.jl
+++ b/GNNLux/src/layers/temporalconv.jl
@@ -421,20 +421,16 @@ function DCGRUCell(ch::Pair{Int, Int}, k::Int; use_bias = true, init_weight = gl
     return DCGRUCell(in_dims, out_dims, k, dconv_u, dconv_r, dconv_c, init_state)
 end
 
-function (l::DCGRUCell)(g, (x, h), ps, st)
-    if h === nothing
-        h = l.init_state(l.out_dims, g.num_nodes)
-    end
-    h̃ = vcat(x, h)
-    z, st_dconv_u = l.dconv_u(g, h̃, ps.dconv_u, st.dconv_u)
-    z = NNlib.sigmoid_fast.(z)
-    r, st_dconv_r = l.dconv_r(g, h̃, ps.dconv_r, st.dconv_r)
-    r = NNlib.sigmoid_fast.(r)
-    ĥ = vcat(x, h .* r)
-    c, st_dconv_c = l.dconv_c(g, ĥ, ps.dconv_c, st.dconv_c)
-    c = NNlib.tanh_fast.(c)
-    h = z.* h + (1 .- z).* c
-    return (h, h), (dconv_u = st_dconv_u, dconv_r = st_dconv_r, dconv_c = st_dconv_c)
+
+function (l::DCGRUCell)(g, x::AbstractMatrix, ps, st)
+    h = l.init_state(l.out_dims, g.num_nodes)
+    return l(g, (x, (h,)), ps, st)
+end
+
+function (l::DCGRUCell)(g, (x, (h,))::Tuple, ps, st)
+    m = StatefulLuxLayer{true}(l, ps, st)
+    h, _ = GNNlib.dcgrucell_frwd(m, g, x, h)
+    return (h, (h,)), _getstate(m)
 end
 
 function Base.show(io::IO, l::DCGRUCell)
diff --git a/GNNLux/test/Project.toml b/GNNLux/test/Project.toml
index d651808bf..46f8e45a3 100644
--- a/GNNLux/test/Project.toml
+++ b/GNNLux/test/Project.toml
@@ -13,6 +13,7 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
+Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
diff --git a/GNNLux/test/test_module.jl b/GNNLux/test/test_module.jl
index 83f3e7062..8e99d41d5 100644
--- a/GNNLux/test/test_module.jl
+++ b/GNNLux/test/test_module.jl
@@ -1,77 +1,5 @@
 @testmodule TestModuleLux begin
 
-using Pkg
-
-## Uncomment below to change the default test settings
-# ENV["GNN_TEST_CUDA"] = "true"
-# ENV["GNN_TEST_AMDGPU"] = "true"
-# ENV["GNN_TEST_Metal"] = "true"
-
-to_test(backend) = get(ENV, "GNN_TEST_$(backend)", "false") == "true"
-has_dependecies(pkgs) = all(pkg -> haskey(Pkg.project().dependencies, pkg), pkgs)
-deps_dict = Dict(:CUDA => ["CUDA", "cuDNN"], :AMDGPU => ["AMDGPU"], :Metal => ["Metal"])
-
-for (backend, deps) in deps_dict
-    if to_test(backend)
-        if !has_dependecies(deps)
-            Pkg.add(deps)
-        end
-        @eval using $backend
-        if backend == :CUDA
-            @eval using cuDNN
-        end
-        @eval $backend.allowscalar(false)
-    end
-end
-
-using Reexport: @reexport
-
-@reexport using Test
-@reexport using GNNLux
-@reexport using Lux
-@reexport using StableRNGs
-@reexport using Random, Statistics
-
-using LuxTestUtils: test_gradients, AutoReverseDiff, AutoTracker, AutoForwardDiff, AutoEnzyme
-
-export test_lux_layer
-
-function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
-            outputsize=nothing, sizey=nothing, container=false,
-            atol=1.0f-2, rtol=1.0f-2, e=nothing)
-
-    if container
-        @test l isa GNNContainerLayer
-    else
-        @test l isa GNNLayer
-    end
-
-    ps = LuxCore.initialparameters(rng, l)
-    st = LuxCore.initialstates(rng, l)
-    @test LuxCore.parameterlength(l) == LuxCore.parameterlength(ps)
-    @test LuxCore.statelength(l) == LuxCore.statelength(st)
-
-    if e !== nothing
-        y, st′ = l(g, x, e, ps, st)
-    else
-        y, st′ = l(g, x, ps, st)
-    end
-    @test eltype(y) == eltype(x)
-    if outputsize !== nothing
-        @test LuxCore.outputsize(l) == outputsize
-    end
-    if sizey !== nothing
-        @test size(y) == sizey
-    elseif outputsize !== nothing
-        @test size(y) == (outputsize..., g.num_nodes)
-    end
-
-    if e !== nothing
-        loss = (x, ps) -> sum(first(l(g, x, e, ps, st)))
-    else
-        loss = (x, ps) -> sum(first(l(g, x, ps, st)))
-    end
-    test_gradients(loss, x, ps; atol, rtol, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
-end
+include("test_utils.jl")
 
 end
diff --git a/GNNLux/test/test_utils.jl b/GNNLux/test/test_utils.jl
new file mode 100644
index 000000000..1d6ac4e1b
--- /dev/null
+++ b/GNNLux/test/test_utils.jl
@@ -0,0 +1,74 @@
+using Pkg
+
+## Uncomment below to change the default test settings
+# ENV["GNN_TEST_CUDA"] = "true"
+# ENV["GNN_TEST_AMDGPU"] = "true"
+# ENV["GNN_TEST_Metal"] = "true"
+
+to_test(backend) = get(ENV, "GNN_TEST_$(backend)", "false") == "true"
+has_dependencies(pkgs) = all(pkg -> haskey(Pkg.project().dependencies, pkg), pkgs)
+deps_dict = Dict(:CUDA => ["CUDA", "cuDNN"], :AMDGPU => ["AMDGPU"], :Metal => ["Metal"])
Dict(:CUDA => ["CUDA", "cuDNN"], :AMDGPU => ["AMDGPU"], :Metal => ["Metal"]) + +for (backend, deps) in deps_dict + if to_test(backend) + if !has_dependecies(deps) + Pkg.add(deps) + end + @eval using $backend + if backend == :CUDA + @eval using cuDNN + end + @eval $backend.allowscalar(false) + end +end + +using Reexport: @reexport + +@reexport using Test +@reexport using GNNLux +@reexport using Lux +@reexport using StableRNGs +@reexport using Random, Statistics + +using LuxTestUtils: test_gradients, AutoReverseDiff, AutoTracker, AutoForwardDiff, AutoEnzyme + +export test_lux_layer + +function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x; + outputsize=nothing, sizey=nothing, container=false, + atol=1.0f-2, rtol=1.0f-2, e=nothing) + + if container + @test l isa GNNContainerLayer + else + @test l isa GNNLayer + end + + ps = LuxCore.initialparameters(rng, l) + st = LuxCore.initialstates(rng, l) + @test LuxCore.parameterlength(l) == LuxCore.parameterlength(ps) + @test LuxCore.statelength(l) == LuxCore.statelength(st) + + if e !== nothing + y, st′ = l(g, x, e, ps, st) + else + y, st′ = l(g, x, ps, st) + end + @test eltype(y) == eltype(x) + if outputsize !== nothing + @test LuxCore.outputsize(l) == outputsize + end + if sizey !== nothing + @test size(y) == sizey + elseif outputsize !== nothing + @test size(y) == (outputsize..., g.num_nodes) + end + + if e !== nothing + loss = (x, ps) -> sum(first(l(g, x, e, ps, st))) + else + loss = (x, ps) -> sum(first(l(g, x, ps, st))) + end + test_gradients(loss, x, ps; atol, rtol, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()]) +end + diff --git a/GNNlib/src/GNNlib.jl b/GNNlib/src/GNNlib.jl index 1941f5752..f61004912 100644 --- a/GNNlib/src/GNNlib.jl +++ b/GNNlib/src/GNNlib.jl @@ -61,7 +61,12 @@ export agnn_conv, transformer_conv include("layers/temporalconv.jl") -export tgcn_conv +export a3tgcn_conv, + dcgrucell_frwd, + evolvegcnocell_frwd, + gconvgrucell_frwd, + gconvlstmcell_frwd, + tgcn_frwd include("layers/pool.jl") export global_pool, diff --git a/GNNlib/src/layers/temporalconv.jl b/GNNlib/src/layers/temporalconv.jl index 8cff3f033..b43c4eba7 100644 --- a/GNNlib/src/layers/temporalconv.jl +++ b/GNNlib/src/layers/temporalconv.jl @@ -10,3 +10,66 @@ function a3tgcn_conv(a3tgcn, g::GNNGraph, x::AbstractArray) return c end + +function gconvgrucell_frwd(cell, g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix) + # reset gate + r = cell.conv_x_r(g, x) .+ cell.conv_h_r(g, h) + r = NNlib.sigmoid_fast(r) + # update gate + z = cell.conv_x_z(g, x) .+ cell.conv_h_z(g, h) + z = NNlib.sigmoid_fast(z) + # new gate + h̃ = cell.conv_x_h(g, x) .+ cell.conv_h_h(g, r .* h) + h̃ = NNlib.tanh_fast(h̃) + h = (1 .- z) .* h̃ .+ z .* h + return h, h +end + + +function gconvlstmcell_frwd(cell, g::GNNGraph, x::AbstractMatrix, (h, c)) + # input gate + i = cell.conv_x_i(g, x) .+ cell.conv_h_i(g, h) .+ cell.w_i .* c .+ cell.b_i + i = NNlib.sigmoid_fast(i) + # forget gate + f = cell.conv_x_f(g, x) .+ cell.conv_h_f(g, h) .+ cell.w_f .* c .+ cell.b_f + f = NNlib.sigmoid_fast(f) + # cell state + c = f .* c .+ i .* NNlib.tanh_fast(cell.conv_x_c(g, x) .+ cell.conv_h_c(g, h) .+ cell.w_c .* c .+ cell.b_c) + # output gate + o = cell.conv_x_o(g, x) .+ cell.conv_h_o(g, h) .+ cell.w_o .* c .+ cell.b_o + o = NNlib.sigmoid_fast(o) + h = o .* NNlib.tanh_fast(c) + return h, (h, c) +end + +function dcgrucell_frwd(cell, g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix) + h̃ = vcat(x, h) + z = cell.dconv_u(g, h̃) + z = NNlib.sigmoid_fast.(z) + r = cell.dconv_r(g, 
+    r = NNlib.sigmoid_fast.(r)
+    ĥ = vcat(x, h .* r)
+    c = cell.dconv_c(g, ĥ)
+    c = NNlib.tanh_fast.(c)
+    h = z .* h + (1 .- z) .* c
+    return h, h
+end
+
+
+function evolvegcnocell_frwd(cell, g::GNNGraph, x::AbstractMatrix, state)
+    weight, state_lstm = cell.lstm(state.weight, state.lstm)
+    x = cell.conv(g, x, conv_weight = reshape(weight, (cell.out, cell.in)))
+    return x, (; weight, lstm = state_lstm)
+end
+
+
+function tgcncell_frwd(cell, g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix)
+    z = cell.conv_z(g, x)
+    z = cell.dense_z(vcat(z, h))
+    r = cell.conv_r(g, x)
+    r = cell.dense_r(vcat(r, h))
+    h̃ = cell.conv_h(g, x)
+    h̃ = cell.dense_h(vcat(h̃, r .* h))
+    h = (1 .- z) .* h .+ z .* h̃
+    return h, h
+end
\ No newline at end of file
diff --git a/GNNlib/test/Project.toml b/GNNlib/test/Project.toml
index 36fcae23b..19f547aed 100644
--- a/GNNlib/test/Project.toml
+++ b/GNNlib/test/Project.toml
@@ -4,7 +4,6 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
 GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
-GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
 Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
 MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
@@ -19,4 +18,3 @@ TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
-GPUArraysCore = "0.1"
diff --git a/GraphNeuralNetworks/src/layers/temporalconv.jl b/GraphNeuralNetworks/src/layers/temporalconv.jl
index 8db7bbc38..1e9307292 100644
--- a/GraphNeuralNetworks/src/layers/temporalconv.jl
+++ b/GraphNeuralNetworks/src/layers/temporalconv.jl
@@ -1,3 +1,6 @@
+# Temporal Convolutional Layers for Graph Neural Networks
+# Implementations are found in GNNlib
+
 function scan(cell, g::GNNGraph, x::AbstractArray{T,3}, state) where {T}
     y = []
     for xt in eachslice(x, dims = 2)
@@ -241,17 +244,7 @@ function (cell::GConvGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractVector)
 end
 
 function (cell::GConvGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix)
-    # reset gate
-    r = cell.conv_x_r(g, x) .+ cell.conv_h_r(g, h)
-    r = Flux.sigmoid_fast(r)
-    # update gate
-    z = cell.conv_x_z(g, x) .+ cell.conv_h_z(g, h)
-    z = Flux.sigmoid_fast(z)
-    # new gate
-    h̃ = cell.conv_x_h(g, x) .+ cell.conv_h_h(g, r .* h)
-    h̃ = Flux.tanh_fast(h̃)
-    h = (1 .- z) .* h̃ .+ z .* h
-    return h, h
+    return GNNlib.gconvgrucell_frwd(cell, g, x, h)
 end
 
 function Base.show(io::IO, cell::GConvGRUCell)
@@ -422,19 +415,7 @@ function (cell::GConvLSTMCell)(g::GNNGraph, x::AbstractMatrix, (h, c))
         c = repeat(c, 1, g.num_nodes)
     end
     @assert ndims(h) == 2 && ndims(c) == 2
-    # input gate
-    i = cell.conv_x_i(g, x) .+ cell.conv_h_i(g, h) .+ cell.w_i .* c .+ cell.b_i
-    i = Flux.sigmoid_fast(i)
-    # forget gate
-    f = cell.conv_x_f(g, x) .+ cell.conv_h_f(g, h) .+ cell.w_f .* c .+ cell.b_f
-    f = Flux.sigmoid_fast(f)
-    # cell state
-    c = f .* c .+ i .* Flux.tanh_fast(cell.conv_x_c(g, x) .+ cell.conv_h_c(g, h) .+ cell.w_c .* c .+ cell.b_c)
-    # output gate
-    o = cell.conv_x_o(g, x) .+ cell.conv_h_o(g, h) .+ cell.w_o .* c .+ cell.b_o
-    o = Flux.sigmoid_fast(o)
-    h = o .* Flux.tanh_fast(c)
-    return h, (h, c)
+    return GNNlib.gconvlstmcell_frwd(cell, g, x, (h, c))
 end
 
 function Base.show(io::IO, cell::GConvLSTMCell)
@@ -535,7 +516,7 @@ julia> size(y) # (d_out, num_nodes)
 (3, 5)
 ```
 """
-struct DCGRUCell
+struct DCGRUCell <: GNNLayer
     in::Int
     out::Int
     k::Int
@@ -563,16 +544,7 @@ function (cell::DCGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractVector)
 end
 
 function (cell::DCGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix)
-    h̃ = vcat(x, h)
-    z = cell.dconv_u(g, h̃)
-    z = NNlib.sigmoid_fast.(z)
-    r = cell.dconv_r(g, h̃)
-    r = NNlib.sigmoid_fast.(r)
-    ĥ = vcat(x, h .* r)
-    c = cell.dconv_c(g, ĥ)
-    c = NNlib.tanh_fast.(c)
-    h = z.* h + (1 .- z) .* c
-    return h, h
+    return GNNlib.dcgrucell_frwd(cell, g, x, h)
 end
 
 function Base.show(io::IO, cell::DCGRUCell)
@@ -700,9 +672,7 @@ end
 (cell::EvolveGCNOCell)(g::GNNGraph, x::AbstractMatrix) = cell(g, x, initialstates(cell))
 
 function (cell::EvolveGCNOCell)(g::GNNGraph, x::AbstractMatrix, state)
-    weight, state_lstm = cell.lstm(state.weight, state.lstm)
-    x = cell.conv(g, x, conv_weight = reshape(weight, (cell.out, cell.in)))
-    return x, (; weight, lstm = state_lstm)
+    return GNNlib.evolvegcnocell_frwd(cell, g, x, state)
 end
 
 function Base.show(io::IO, egcno::EvolveGCNOCell)
@@ -845,14 +815,7 @@ function (cell::TGCNCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractVector)
 end
 
 function (cell::TGCNCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix)
-    z = cell.conv_z(g, x)
-    z = cell.dense_z(vcat(z, h))
-    r = cell.conv_r(g, x)
-    r = cell.dense_r(vcat(r, h))
-    h̃ = cell.conv_h(g, x)
-    h̃ = cell.dense_h(vcat(h̃, r .* h))
-    h = (1 .- z) .* h .+ z .* h̃
-    return h, h
+    return GNNlib.tgcncell_frwd(cell, g, x, h)
 end
 
 function Base.show(io::IO, cell::TGCNCell)
diff --git a/GraphNeuralNetworks/test/Project.toml b/GraphNeuralNetworks/test/Project.toml
index ebdb52172..1406ef2bf 100644
--- a/GraphNeuralNetworks/test/Project.toml
+++ b/GraphNeuralNetworks/test/Project.toml
@@ -5,7 +5,6 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
 GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
-GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
 GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
 Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
 MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
@@ -18,4 +17,3 @@ TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
-GPUArraysCore = "0.1"
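The Flux-side cells in this patch keep the `(y, state) = cell(g, x, state)` step convention (each forward returns `h, h` or `h, (h, c)`) while delegating the gate math to the shared GNNlib kernels. A minimal usage sketch of that convention, assuming the `GConvGRUCell(in => out, k)` constructor and `rand_graph` from the package docs:

```julia
using GraphNeuralNetworks

g = rand_graph(5, 10)                  # 5 nodes, 10 edges
cell = GConvGRUCell(2 => 3, 2)         # assumed: in_dims => out_dims, Chebyshev order k
h = zeros(Float32, 3, g.num_nodes)     # initial hidden state
x1 = rand(Float32, 2, g.num_nodes)     # node features at step 1
y1, h = cell(g, x1, h)                 # one recurrence step; dispatches to gconvgrucell_frwd
x2 = rand(Float32, 2, g.num_nodes)
y2, h = cell(g, x2, h)
@assert size(y2) == (3, g.num_nodes)   # (d_out, num_nodes), as in the DCGRUCell docstring
```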
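The `Base.getproperty` overload added in `GNNLux/src/layers/basic.jl` is what lets those shared kernels, which reach into sub-layers as plain properties (`cell.dconv_u(g, h̃)`), run on a Lux cell: the cell is wrapped in a `StatefulLuxLayer`, and reading a sub-layer property hands back that sub-layer already bound to its slice of `ps` and `st`. A self-contained mock of the idea, using hypothetical stand-in types rather than the real Lux API:

```julia
struct Leaf end                      # stand-in for a stateless layer definition
(l::Leaf)(x, ps) = ps.W * x          # Lux-style call: parameters passed explicitly

struct Container                     # stand-in for a container layer with one sub-layer
    sub::Leaf
end

struct Stateful{L}                   # stand-in for StatefulLuxLayer
    model::L
    ps::NamedTuple
end
(m::Stateful{Leaf})(x) = m.model(x, m.ps)   # leaf call: parameters already bound

# The trick: property access on the wrapper returns the sub-layer re-wrapped
# with its matching parameter slice, so shared forward functions never need
# to thread `ps` explicitly.
function Base.getproperty(m::Stateful{Container}, name::Symbol)
    hasfield(typeof(m), name) && return getfield(m, name)
    f = getfield(getfield(m, :model), name)   # sub-layer of the wrapped model
    psf = getfield(getfield(m, :ps), name)    # matching parameter slice
    return Stateful(f, psf)
end

ps = (; sub = (; W = ones(3, 2)))
m = Stateful(Container(Leaf()), ps)
@assert m.sub(ones(2)) == fill(2.0, 3)        # `m.sub` is callable like a Flux layer
```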
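The `scan` context lines at the top of `GraphNeuralNetworks/src/layers/temporalconv.jl` show how a cell is lifted to a whole temporal sequence: the input is a `(features, time, nodes)` array, and the cell is applied slice-by-slice along `dims = 2` while its state is threaded through. A graph-free toy sketch of that pattern (`scan_demo` and `toycell` are hypothetical):

```julia
function scan_demo(cell, x::AbstractArray{T,3}, state) where {T}
    ys = Matrix{T}[]
    for xt in eachslice(x, dims = 2)     # xt has size (features, nodes)
        yt, state = cell(xt, state)
        push!(ys, yt)
    end
    return stack(ys; dims = 2), state    # reassemble to (features, time, nodes)
end

# toy "cell": cumulative mean over time; state = (running sum, step count)
toycell(xt, st) = ((st.sum .+ xt) ./ (st.n + 1), (; sum = st.sum .+ xt, n = st.n + 1))

x = rand(Float32, 4, 10, 7)              # 4 features, 10 time steps, 7 nodes
y, st = scan_demo(toycell, x, (; sum = zeros(Float32, 4, 7), n = 0))
@assert size(y) == (4, 10, 7)
```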