Skip to content

Commit 4c40328

Browse files
committed
Partial first, derivative at concept
1 parent 917c4da commit 4c40328

File tree

4 files changed

+35
-15
lines changed

4 files changed

+35
-15
lines changed

src/diffKernel.jl

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,36 @@ import ForwardDiff as FD
22
import LinearAlgebra as LA
33
using KernelFunctions: SimpleKernel, Kernel
44

5+
const DiffPt = Tuple{Partial,Vararg} # allow for one dimensional and MultiOutput kernels
56
"""
6-
diffKernelCall(k::T, (x,px)::DiffPt, (y,py)::DiffPt) where {Dim, T<:Kernel}
7+
diffKernelCall(k::T, (px,x)::DiffPt, (py, y)::DiffPt) where {Dim, T<:Kernel}
78
89
specialization for DiffPt. Unboxes the partial instructions from DiffPt and
910
applies them to k, evaluates them at the positions of DiffPt
1011
"""
11-
function diffKernelCall(k::T, (x,px)::DiffPt, (y,py)::DiffPt) where {T<:Kernel}
12+
function diffKernelCall(k::T, (px, x)::Tuple{Partial,Pos1}, (py, y)::Tuple{Partial,Pos2}) where {T<:Kernel,Pos1,Pos2}
13+
# need Pos1 and Pos2 because k(1,1.) is allowed (combination of Int and Float) maybe there is a better solution resulting in more type safety?
1214
return apply_partial(k, px.indices, py.indices)(x, y)
1315
end
16+
"""
17+
Multi Kernel Version (do not try to take the derivative with regard to out indices)
18+
"""
19+
function diffKernelCall(
20+
k::T,
21+
(px, x, x_out)::Tuple{Partial,Pos1,Idx1},
22+
(py, y, y_out)::Tuple{Partial,Pos2,Idx2}
23+
) where {T<:MOKernel,Pos1,Idx1,Pos2,Idx2}
24+
return apply_partial((x, y) -> k((x, x_out), (y, y_out)), px.indices, py.indices)(x, y)
25+
end
1426

1527
"""
1628
EnableDiff
1729
1830
A thin wrapper around Kernels enabling the machinery which allows you to
19-
input (x, ∂ᵢ), (y, ∂ⱼ) where ∂ᵢ, ∂ⱼ are of `Partial` type (see [partial](@ref)) in order
31+
input (∂ᵢ, x), (∂ⱼ, y) where ∂ᵢ, ∂ⱼ are of `Partial` type (see [partial](@ref)) in order
2032
to calculate
2133
``
22-
k((x, ∂ᵢ), (y,∂ⱼ)) = \\text{Cov}(\\partial_i Z(x), \\partial_j Z(y))
34+
k((∂ᵢ, x), (∂ⱼ, y)) = \\text{Cov}(\\partial_i Z(x), \\partial_j Z(y))
2335
``
2436
for ``Z`` with ``k(x,y) = \\text{Cov}(Z(x), Z(y))``.
2537
@@ -47,7 +59,7 @@ struct EnableDiff{T<:Kernel} <: Kernel
4759
kernel::T
4860
end
4961
(k::EnableDiff)(x::DiffPt, y::DiffPt) = diffKernelCall(k.kernel, x, y)
50-
(k::EnableDiff)(x::DiffPt, y) = diffKernelCall(k.kernel, x,(y, partial()))
51-
(k::EnableDiff)(x, y::DiffPt) = diffKernelCall(k.kernel, (x, partial()), y)
52-
(k::EnableDiff)(x, y) = k.kernel(x,y) # Fall through case
62+
(k::EnableDiff)(x::DiffPt, y) = diffKernelCall(k.kernel, x, (partial(), y))
63+
(k::EnableDiff)(x, y::DiffPt) = diffKernelCall(k.kernel, (partial(), x), y)
64+
(k::EnableDiff)(x, y) = k.kernel(x, y) # Fall through case
5365

src/partial.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ for T in [MIME"text/plain", MIME"text/html"]
4949
end
5050
end
5151

52-
const DiffPt{T} = Tuple{T,Partial}
5352

5453
function fullderivative(::Val{order}, input_indices::AbstractVector{Int}) where {order}
5554
return mappedarray(partial, productArray(ntuple(_ -> input_indices, Val{order}())...))
@@ -65,6 +64,15 @@ gradient(dim::Integer) = fullderivative(Val(1), dim)
6564
hessian(input_indices::AbstractArray) = fullderivative(Val(2), input_indices)
6665
hessian(dim::Integer) = fullderivative(Val(2), dim)
6766

67+
diffAt(::Val{order}, x) where {order} = productArray(Ref(x), _diffAt(Base.IteratorSize(x), Val(order), x))
68+
_diffAt(::Base.HasLength, ::Val{order}, x) where {order} = fullderivative(Val(order), Base.OneTo(length(x)))
69+
_diffAt(::Base.HasShape{1}, ::Val{order}, x) where {order} = fullderivative(Val(order), Base.OneTo(length(x)))
70+
_diffAt(::Base.HasShape, ::Val{order}, x) where {order} = fullderivative(Val(order), CartesianIndices(axes(x)))
71+
72+
gradAt(x) = diffAt(Val(1), x)
73+
grad(f) = x -> f.(gradAt(x)) # for f = rand(::GP), grad(f)(x) should work.
74+
75+
6876
# idea: lazy mappings can be undone (extract original range -> towards a specialization speedup of broadcasting over multiple derivatives using backwardsdiff)
6977
const MappedPartialVec{T} = ReadonlyMappedArray{Partial{1,Tuple{Int}},1,T,typeof(partial)}
7078
function extract_range(p_map::MappedPartialVec{T}) where {T<:AbstractUnitRange{Int}}

test/diffKernel.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
k = EnableDiff(MaternKernel())
44
k2 = MaternKernel()
55
@test k(1, 1) == k2(1, 1)
6-
k(1, (1, partial(1, 1))) # Cov(Z(x), ∂₁∂₁Z(y)) where x=1, y=1
7-
k(([1], partial(1)), [2]) # Cov(∂₁Z(x), Z(y)) where x=[1], y=[2]
6+
k(1, (partial(1, 1), 1)) # Cov(Z(x), ∂₁∂₁Z(y)) where x=1, y=1
7+
k((partial(1), [1]), [2]) # Cov(∂₁Z(x), Z(y)) where x=[1], y=[2]
88
end
99

1010
@testset "Sanity Checks with $k1" for k1 in [
@@ -19,19 +19,19 @@
1919
## This fails for Matern and RationalQuadraticKernel
2020
# because its implementation branches on x == y resulting in a zero derivative
2121
# (cf. https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/issues/517)
22-
@test k((x, partial(1)), (x, partial(1))) > 0
22+
@test k((partial(1), x), (partial(1), x)) > 0
2323

2424
# the slope should be positively correlated with a point further down
2525
@test k(
26-
(x, partial(1)), # slope
26+
(partial(1), x), # slope
2727
x + 1e-2, # point further down
2828
) > 0
2929

3030
@testset "Stationary Tests" begin
31-
@test k((x, partial(1)), x) == 0 # expect Cov(∂Z(x) , Z(x)) == 0
31+
@test k((partial(1), x), x) == 0 # expect Cov(∂Z(x) , Z(x)) == 0
3232

3333
@testset "Isotropic Tests" begin
34-
@test k(([1, 2], partial(1)), ([1, 2], partial(2))) == 0 # cross covariance should be zero
34+
@test k((partial(1), [1, 2]), (partial(2), [1, 2])) == 0 # cross covariance should be zero
3535
end
3636
end
3737
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using KernelFunctions: KernelFunctions as KF, MaternKernel, SEKernel, RationalQuadraticKernel
1+
using KernelFunctions: KernelFunctions as KF, MaternKernel, SEKernel, RationalQuadraticKernel, Matern32Kernel, Matern52Kernel
22
using DifferentiableKernelFunctions: DifferentiableKernelFunctions as DKF, EnableDiff, partial
33
using ProductArrays: productArray
44
using Test

0 commit comments

Comments
 (0)