Skip to content

Commit 8e805ef

Browse files
authored
fix MaternKernel AD, but remove differentiation wrt \nu (#425)
* reactivate test_ADs of MaternKernel, but do not test differentiation through \nu argument * make clear in docstring that MaternKernel does not support derivative w.r.t. order \nu * work-around for Zygote AD that will return nothing/NaN gradients w.r.t. \nu but fixes the gradient w.r.t. x
1 parent 873aa8d commit 8e805ef

File tree

4 files changed

+14
-8
lines changed

4 files changed

+14
-8
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.36"
3+
version = "0.10.37"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/basekernels/matern.jl

+9-2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ By default, ``d`` is the Euclidean metric ``d(x, x') = \\|x - x'\\|_2``.
1717
A Gaussian process with a Matérn kernel is ``\\lceil \\nu \\rceil - 1``-times
1818
differentiable in the mean-square sense.
1919
20+
!!! note
21+
22+
Differentiation with respect to the order ν is not currently supported.
23+
2024
See also: [`Matern12Kernel`](@ref), [`Matern32Kernel`](@ref), [`Matern52Kernel`](@ref)
2125
"""
2226
struct MaternKernel{Tν<:Real,M} <: SimpleKernel
@@ -33,8 +37,11 @@ MaternKernel(; nu::Real=1.5, ν::Real=nu, metric=Euclidean()) = MaternKernel(ν,
3337

3438
@functor MaternKernel
3539

36-
@inline function kappa::MaternKernel, d::Real)
37-
result = _matern(only.ν), d)
40+
@inline _get_ν(k::MaternKernel) = only(k.ν)
41+
ChainRulesCore.@non_differentiable _get_ν(k) # work-around; should be "NotImplemented" rather than NoTangent
42+
43+
@inline function kappa(k::MaternKernel, d::Real)
44+
result = _matern(_get_ν(k), d)
3845
return ifelse(iszero(d), one(result), result)
3946
end
4047

src/matrix/kernelkroneckermat.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ where `D` is given by `dims`.
1414
1515
!!! warning
1616
17-
Require `Kronecker.jl` and for `iskroncompatible(κ)` to return `true`.
17+
Requires `Kronecker.jl` and for `iskroncompatible(κ)` to return `true`.
1818
"""
1919
function kernelkronmat::Kernel, X::AbstractVector{<:Real}, dims::Int)
2020
checkkroncompatible(κ)

test/basekernels/matern.jl

+3-4
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,24 @@
44
v1 = rand(rng, 3)
55
v2 = rand(rng, 3)
66
@testset "MaternKernel" begin
7-
ν = 2.0
7+
ν = 2.1
88
k = MaternKernel(; ν=ν)
99
matern(x, ν) = 2^(1 - ν) / gamma(ν) * (sqrt(2ν) * x)^ν * besselk(ν, sqrt(2ν) * x)
1010
@test MaternKernel(; nu=ν).ν == [ν]
1111
@test kappa(k, x) matern(x, ν)
1212
@test kappa(k, 0.0) == 1.0
13-
@test kappa(MaternKernel(; ν=ν), x) == kappa(k, x)
1413
@test metric(MaternKernel()) == Euclidean()
1514
@test metric(MaternKernel(; ν=2.0)) == Euclidean()
1615
@test repr(k) == "Matern Kernel (ν = $(ν), metric = Euclidean(0.0))"
17-
# test_ADs(x->MaternKernel(nu=first(x)),[ν])
18-
@test_broken "All fails (because of logabsgamma for ForwardDiff and ReverseDiff and because of nu for Zygote)"
1916

2017
k2 = MaternKernel(; ν=ν, metric=WeightedEuclidean(ones(3)))
2118
@test metric(k2) isa WeightedEuclidean
2219
@test k2(v1, v2) k(v1, v2)
2320

2421
# Standardised tests.
2522
TestUtils.test_interface(k, Float64)
23+
test_ADs(() -> MaternKernel(; nu=ν))
24+
2625
test_params(k, ([ν],))
2726
end
2827
@testset "Matern32Kernel" begin

0 commit comments

Comments
 (0)