diff --git a/Project.toml b/Project.toml index 008a5fdd1..b7361ed72 100644 --- a/Project.toml +++ b/Project.toml @@ -20,6 +20,7 @@ NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d" NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +SnoopPrecompile = "66db9d55-30c0-4569-8b51-7e840670fc0c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" @@ -40,6 +41,7 @@ NNlib = "0.8" NNlibCUDA = "0.2" NearestNeighbors = "0.4" Reexport = "1" +SnoopPrecompile = "1" StatsBase = "0.33" julia = "1.7" diff --git a/invalidations.jl b/invalidations.jl new file mode 100644 index 000000000..6d41a2b12 --- /dev/null +++ b/invalidations.jl @@ -0,0 +1,55 @@ +using SnoopCompileCore + +invalidations = @snoopr begin + using GraphNeuralNetworks + using Flux + # using CUDA + # using Graphs + # using Random, Statistics, LinearAlgebra +end + +function workload() + num_graphs = 3 + gs = [rand_graph(5, 10) for _ in 1:num_graphs] + g = Flux.batch(gs) + x = rand(Float32, 4, g.num_nodes) + model = GNNChain(GCNConv(4 => 4, relu), + GCNConv(4 => 4), + GlobalPool(max), + Dense(4, 1)) + y = model(g, x) + # @assert size(y) == (1, num_graphs) +end + +tinf = @snoopi_deep begin + workload() +end + +using SnoopCompile +trees = invalidation_trees(invalidations) +staletrees = precompile_blockers(trees, tinf) + +@show length(uinvalidated(invalidations)) # show total invalidations + +show(trees[end]) # show the most invalidating method + +# Count number of children (number of invalidations per invalidated method) +n_invalidations = map(SnoopCompile.countchildren, trees) + +# (optional) plot the number of children per method invalidations +import Plots +Plots.plot( + 1:length(trees), + n_invalidations; + markershape=:circle, + xlabel="i-th method invalidation", + label="Number of children per method invalidations" +) + +# (optional) report invalidations summary +using PrettyTables # needed for `report_invalidations` to be defined +SnoopCompile.report_invalidations(; + invalidations, + process_filename = x -> last(split(x, ".julia/packages/")), + n_rows = 0, # no-limit (show all invalidations) + ) \ No newline at end of file diff --git a/src/GNNGraphs/gatherscatter.jl b/src/GNNGraphs/gatherscatter.jl index f7aeadf29..c1f781260 100644 --- a/src/GNNGraphs/gatherscatter.jl +++ b/src/GNNGraphs/gatherscatter.jl @@ -4,10 +4,10 @@ _gather(x::Tuple, i) = map(x -> _gather(x, i), x) _gather(x::AbstractArray, i) = NNlib.gather(x, i) _gather(x::Nothing, i) = nothing -_scatter(aggr, src::Nothing, idx, n) = nothing -_scatter(aggr, src::NamedTuple, idx, n) = map(s -> _scatter(aggr, s, idx, n), src) -_scatter(aggr, src::Tuple, idx, n) = map(s -> _scatter(aggr, s, idx, n), src) -_scatter(aggr, src::Dict, idx, n) = Dict(k => _scatter(aggr, v, idx, n) for (k, v) in src) +_scatter(aggr::A, src::Nothing, idx, n) where A = nothing +_scatter(aggr::A, src::NamedTuple, idx, n) where A = map(s -> _scatter(aggr, s, idx, n), src) +_scatter(aggr::A, src::Tuple, idx, n) where A = map(s -> _scatter(aggr, s, idx, n), src) +_scatter(aggr::A, src::Dict, idx, n) where A = Dict(k => _scatter(aggr, v, idx, n) for (k, v) in src) function _scatter(aggr, src::AbstractArray, diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl index 06dfa178a..c02f10046 100644 --- a/src/GraphNeuralNetworks.jl +++ b/src/GraphNeuralNetworks.jl @@ -2,7 +2,7 @@ module GraphNeuralNetworks using Statistics: mean using LinearAlgebra, Random -using Base: tail +using Base: tail, Fix1, Fix2 using CUDA using Flux using Flux: glorot_uniform, leakyrelu, GRUCell, @functor, batch @@ -11,6 +11,7 @@ using NNlib, NNlibCUDA using NNlib: scatter, gather using ChainRulesCore using Reexport +using SnoopPrecompile using SparseArrays, Graphs # not needed but if removed Documenter will complain include("GNNGraphs/GNNGraphs.jl") @@ -83,4 +84,8 @@ include("msgpass.jl") include("mldatasets.jl") include("deprecations.jl") +@precompile_all_calls begin + include("precompile.jl") +end + end diff --git a/src/msgpass.jl b/src/msgpass.jl index a11f0946c..01b0bdfbc 100644 --- a/src/msgpass.jl +++ b/src/msgpass.jl @@ -73,11 +73,11 @@ See also [`apply_edges`](@ref) and [`aggregate_neighbors`](@ref). """ function propagate end -function propagate(f, g::GNNGraph, aggr; xi = nothing, xj = nothing, e = nothing) - propagate(f, g, aggr, xi, xj, e) +function propagate(f::F, g::GNNGraph, aggr; xi = nothing, xj = nothing, e = nothing) where F + propagate(f, g, aggr, xi, xj, e) end -function propagate(f, g::GNNGraph, aggr, xi, xj, e = nothing) +function propagate(f::F, g::GNNGraph, aggr, xi, xj, e = nothing) where F m = apply_edges(f, g, xi, xj, e) m̄ = aggregate_neighbors(g, aggr, m) return m̄ @@ -87,12 +87,12 @@ end # https://github.com/JuliaLang/julia/issues/15276 ## and zygote issues # https://github.com/FluxML/Zygote.jl/issues/1317 -function propagate(f, g::GNNGraph, aggr, l::GNNLayer; xi = nothing, xj = nothing, - e = nothing) - propagate((xi, xj, e) -> f(l, xi, xj, e), g, aggr, xi, xj, e) +function propagate(f::F, g::GNNGraph, aggr, l::GNNLayer; xi = nothing, xj = nothing, + e = nothing) where F + propagate(Fix1(f, l), g, aggr, xi, xj, e) end -function propagate(f, g::GNNGraph, aggr, l::GNNLayer, xi, xj, e = nothing) - propagate((xi, xj, e) -> f(l, xi, xj, e), g, aggr, xi, xj, e) +function propagate(f::F, g::GNNGraph, aggr, l::GNNLayer, xi, xj, e = nothing) where F + propagate(Fix1(f, l), g, aggr, xi, xj, e) end ## APPLY EDGES @@ -135,11 +135,11 @@ See also [`propagate`](@ref) and [`aggregate_neighbors`](@ref). """ function apply_edges end -function apply_edges(f, g::GNNGraph; xi = nothing, xj = nothing, e = nothing) +function apply_edges(f::F, g::GNNGraph; xi = nothing, xj = nothing, e = nothing) where F apply_edges(f, g, xi, xj, e) end -function apply_edges(f, g::GNNGraph, xi, xj, e = nothing) +function apply_edges(f::F, g::GNNGraph, xi, xj, e = nothing) where F check_num_nodes(g, xi) check_num_nodes(g, xj) check_num_edges(g, e) @@ -154,12 +154,12 @@ end # https://github.com/JuliaLang/julia/issues/15276 ## and zygote issues # https://github.com/FluxML/Zygote.jl/issues/1317 -function apply_edges(f, g::GNNGraph, l::GNNLayer; xi = nothing, xj = nothing, e = nothing) - apply_edges((xi, xj, e) -> f(l, xi, xj, e), g, xi, xj, e) +function apply_edges(f::F, g::GNNGraph, l::GNNLayer; xi = nothing, xj = nothing, e = nothing) where F + apply_edges(Fix1(f, l), g, xi, xj, e) end -function apply_edges(f, g::GNNGraph, l::GNNLayer, xi, xj, e = nothing) - apply_edges((xi, xj, e) -> f(l, xi, xj, e), g, xi, xj, e) +function apply_edges(f::F, g::GNNGraph, l::GNNLayer, xi, xj, e = nothing) where F + apply_edges(Fix1(f, l), g, xi, xj, e) end ## AGGREGATE NEIGHBORS @@ -176,7 +176,7 @@ features Neighborhood aggregation is the second step of [`propagate`](@ref), where it comes after [`apply_edges`](@ref). """ -function aggregate_neighbors(g::GNNGraph, aggr, m) +function aggregate_neighbors(g::GNNGraph, aggr::A, m) where {A} check_num_edges(g, m) s, t = edge_index(g) return GNNGraphs._scatter(aggr, m, t, g.num_nodes) diff --git a/src/precompile.jl b/src/precompile.jl new file mode 100644 index 000000000..044b4f62a --- /dev/null +++ b/src/precompile.jl @@ -0,0 +1,16 @@ + +function workflow1() + nnodes, d = 10, 6 + ngraphs = 5 + g = Flux.batch([rand_graph(nnodes, 3*nnodes) for i in 1:ngraphs]) + x = rand(Float32, d, g.num_nodes) + model = GNNChain(GCNConv(d => d, relu), + GraphConv(d => d, tanh), + GATv2Conv(d => d ÷ 2, relu, heads=2), + GlobalPool(max), + Dense(d, 1)) + y = model(g, x) + grad = gradient(m -> sum(m(g, x)), model)[1] +end + +workflow1() \ No newline at end of file