From 0a80d8f70deecd8ea2f11f9ea090095328c0461f Mon Sep 17 00:00:00 2001 From: ST John Date: Tue, 2 May 2023 19:34:44 +0300 Subject: [PATCH 1/5] add enzyme to tests --- test/Project.toml | 1 + test/runtests.jl | 1 + test/test_utils.jl | 25 ++++++++++++++++++++++--- 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index e16f39a6c..d3f8a43a7 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,6 +3,7 @@ AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" diff --git a/test/runtests.jl b/test/runtests.jl index caf43cb91..8ce4443d7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -17,6 +17,7 @@ using Zygote: Zygote using ForwardDiff: ForwardDiff using ReverseDiff: ReverseDiff using FiniteDifferences: FiniteDifferences +using Enzyme: Enzyme using Compat: only using KernelFunctions: SimpleKernel, metric, kappa, ColVecs, RowVecs, TestUtils diff --git a/test/test_utils.jl b/test/test_utils.jl index 8367fbd6d..f22f0c4af 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -47,8 +47,9 @@ gradient(f, s::Symbol, args) = gradient(f, Val(s), args) function gradient(f, ::Val{:Zygote}, args) g = only(Zygote.gradient(f, args)) if isnothing(g) + # To respect the same output as other ADs if args isa AbstractArray{<:Real} - return zeros(size(args)) # To respect the same output as other ADs + return zeros(size(args)) else return zeros.(size.(args)) end @@ -57,6 +58,24 @@ function gradient(f, ::Val{:Zygote}, args) end end +function gradient(f, ::Val{:EnzymeForward}, args) + # shape = size(args) + # f_prime(flatargs) = f(reshape(flatargs, shape...)) + # return Enzyme.gradient(Enzyme.Forward, f_prime, reshape(args, prod(shape))) + d_args = zero(args) + Enzyme.autodiff(Enzyme.Forward, f, Enzyme.Active, Enzyme.Duplicated(args, d_args)) + return d_args +end + +function gradient(f, ::Val{:EnzymeReverse}, args) + # shape = size(args) + # f_prime(flatargs) = f(reshape(flatargs, shape...)) + # return Enzyme.gradient(Enzyme.Reverse, f_prime, reshape(args, prod(shape))) + d_args = zero(args) + Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active, Enzyme.Duplicated(args, d_args)) + return d_args +end + function gradient(f, ::Val{:ForwardDiff}, args) return ForwardDiff.gradient(f, args) end @@ -90,7 +109,7 @@ testdiagfunction(k::MOKernel, A) = sum(kernelmatrix_diag(k, A)) testdiagfunction(k::MOKernel, A, B) = sum(kernelmatrix_diag(k, A, B)) function test_ADs( - kernelfunction, args=nothing; ADs=[:Zygote, :ForwardDiff, :ReverseDiff], dims=[3, 3] + kernelfunction, args=nothing; ADs=[:Zygote, :ForwardDiff, :ReverseDiff, :EnzymeReverse, :EnzymeForward], dims=[3, 3] ) test_fd = test_AD(:FiniteDiff, kernelfunction, args, dims) if !test_fd.anynonpass @@ -108,7 +127,7 @@ function check_zygote_type_stability(f, args...; ctx=Zygote.Context()) end function test_ADs( - k::MOKernel; ADs=[:Zygote, :ForwardDiff, :ReverseDiff], dims=(in=3, out=2, obs=3) + k::MOKernel; ADs=[:Zygote, :ForwardDiff, :ReverseDiff, :EnzymeReverse, :EnzymeForward], dims=(in=3, out=2, obs=3) ) test_fd = test_FiniteDiff(k, dims) if !test_fd.anynonpass From 98e4ce8a08219c27838b60c81eb51a4c25fda8bc Mon Sep 17 00:00:00 2001 From: st-- Date: Tue, 2 May 2023 20:10:06 +0300 Subject: [PATCH 2/5] Update test/test_utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/test_utils.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/test_utils.jl b/test/test_utils.jl index f22f0c4af..e84dbd550 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -109,7 +109,10 @@ testdiagfunction(k::MOKernel, A) = sum(kernelmatrix_diag(k, A)) testdiagfunction(k::MOKernel, A, B) = sum(kernelmatrix_diag(k, A, B)) function test_ADs( - kernelfunction, args=nothing; ADs=[:Zygote, :ForwardDiff, :ReverseDiff, :EnzymeReverse, :EnzymeForward], dims=[3, 3] + kernelfunction, + args=nothing; + ADs=[:Zygote, :ForwardDiff, :ReverseDiff, :EnzymeReverse, :EnzymeForward], + dims=[3, 3], ) test_fd = test_AD(:FiniteDiff, kernelfunction, args, dims) if !test_fd.anynonpass From 18ed6f898db14a16aaf3734b800d6e48f90898b3 Mon Sep 17 00:00:00 2001 From: st-- Date: Tue, 2 May 2023 20:10:12 +0300 Subject: [PATCH 3/5] Update test/test_utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/test_utils.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_utils.jl b/test/test_utils.jl index e84dbd550..a8c96a1ae 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -130,7 +130,9 @@ function check_zygote_type_stability(f, args...; ctx=Zygote.Context()) end function test_ADs( - k::MOKernel; ADs=[:Zygote, :ForwardDiff, :ReverseDiff, :EnzymeReverse, :EnzymeForward], dims=(in=3, out=2, obs=3) + k::MOKernel; + ADs=[:Zygote, :ForwardDiff, :ReverseDiff, :EnzymeReverse, :EnzymeForward], + dims=(in=3, out=2, obs=3), ) test_fd = test_FiniteDiff(k, dims) if !test_fd.anynonpass From 6dc465f9fb5d7a3fef2550a30b3f422ac26a7352 Mon Sep 17 00:00:00 2001 From: ST John Date: Tue, 2 May 2023 22:05:59 +0300 Subject: [PATCH 4/5] fix gradient(f, ::Val{EnzymeForward}, args) --- test/test_utils.jl | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/test/test_utils.jl b/test/test_utils.jl index f22f0c4af..4562f5b6a 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -59,12 +59,7 @@ function gradient(f, ::Val{:Zygote}, args) end function gradient(f, ::Val{:EnzymeForward}, args) - # shape = size(args) - # f_prime(flatargs) = f(reshape(flatargs, shape...)) - # return Enzyme.gradient(Enzyme.Forward, f_prime, reshape(args, prod(shape))) - d_args = zero(args) - Enzyme.autodiff(Enzyme.Forward, f, Enzyme.Active, Enzyme.Duplicated(args, d_args)) - return d_args + return Enzyme.gradient(Enzyme.Forward, f, args) end function gradient(f, ::Val{:EnzymeReverse}, args) From 858dd663eb6a5993bf14d1a2306c75a16c0baf23 Mon Sep 17 00:00:00 2001 From: ST John Date: Wed, 3 May 2023 11:57:57 +0300 Subject: [PATCH 5/5] remove EnzymeForward from default AD tests - crashes julia --- test/test_utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_utils.jl b/test/test_utils.jl index 52d25bebb..744d3b6c2 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -106,7 +106,7 @@ testdiagfunction(k::MOKernel, A, B) = sum(kernelmatrix_diag(k, A, B)) function test_ADs( kernelfunction, args=nothing; - ADs=[:Zygote, :ForwardDiff, :ReverseDiff, :EnzymeReverse, :EnzymeForward], + ADs=[:Zygote, :ForwardDiff, :ReverseDiff, :EnzymeReverse], dims=[3, 3], ) test_fd = test_AD(:FiniteDiff, kernelfunction, args, dims) @@ -126,7 +126,7 @@ end function test_ADs( k::MOKernel; - ADs=[:Zygote, :ForwardDiff, :ReverseDiff, :EnzymeReverse, :EnzymeForward], + ADs=[:Zygote, :ForwardDiff, :ReverseDiff, :EnzymeReverse], dims=(in=3, out=2, obs=3), ) test_fd = test_FiniteDiff(k, dims)