-
Notifications
You must be signed in to change notification settings - Fork 35
/
Copy pathdiffKernel.jl
101 lines (89 loc) · 3.02 KB
/
diffKernel.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import ForwardDiff as FD
import LinearAlgebra as LA
"""
DiffPt(x; partial=())
For a covariance kernel k of GP Z, i.e.
```julia
k(x,y) # = Cov(Z(x), Z(y)),
```
a DiffPt allows the differentiation of Z, i.e.
```julia
k(DiffPt(x, partial=1), y) # = Cov(∂₁Z(x), Z(y))
```
for higher order derivatives partial can be any iterable, i.e.
```julia
k(DiffPt(x, partial=(1,2)), y) # = Cov(∂₁∂₂Z(x), Z(y))
```
"""
IndexType = Union{Int,Base.AbstractCartesianIndex}
struct DiffPt{Order,KeyT<:IndexType,T}
pos::T # the actual position
partials::NTuple{Order,KeyT}
end
DiffPt(x::T) where {T<:AbstractArray} = DiffPt{0,keytype(T),T}(x, ()::NTuple{0,keytype(T)})
DiffPt(x::T) where {T<:Number} = DiffPt{0,Int,T}(x, ()::NTuple{0,Int})
DiffPt(x::T, partial::IndexType) where {T} = DiffPt{1,IndexType,T}(x, (partial,))
function DiffPt(x::T, partials::NTuple{Order,KeyT}) where {T,Order,KeyT}
return DiffPt{Order,KeyT,T}(x, partials)
end
"""
tangentCurve(x₀, i::IndexType)
returns the function (t ↦ x₀ + teᵢ) where eᵢ is the unit vector at index i
"""
function tangentCurve(x0::AbstractArray, idx::IndexType)
return t -> begin
x = similar(x0, promote_type(eltype(x0), typeof(t)))
copyto!(x, x0)
x[idx] += t
return x
end
end
function tangentCurve(x0::Number, ::IndexType)
return t -> x0 + t
end
"""
    partial(func, partials::IndexType...)

Return the function `x ↦ ∂ᵢ₁⋯∂ᵢₙ func(x)`, the partial derivative of `func` in the
directions `partials = (i₁, …, iₙ)`.

Each derivative is computed with `ForwardDiff.derivative` along the curve
`tangentCurve(x, i)` at `t = 0`. With no indices, `func` is returned unchanged.
"""
partial(func) = func
function partial(func, idx::IndexType)
    return function (x)
        along_direction = func ∘ tangentCurve(x, idx)
        return FD.derivative(along_direction, 0)
    end
end
function partial(func, partials::IndexType...)
    # Differentiate in the first direction, then recurse on the remaining directions.
    head = first(partials)
    rest = Base.tail(partials)
    return partial(partial(func, head), rest...)
end
"""
Take the partial derivative of a function with two dim-dimensional inputs,
i.e. 2*dim dimensional input
"""
function partial(
k, partials_x::Tuple{Vararg{T}}, partials_y::Tuple{Vararg{T}}
) where {T<:IndexType}
local f(x, y) = partial(t -> k(t, y), partials_x...)(x)
return (x, y) -> partial(t -> f(x, t), partials_y...)(y)
end
"""
_evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim, T<:Kernel}
implements `(k::T)(x::DiffPt{Dim}, y::DiffPt{Dim})` for all kernel types. But since
generics are not allowed in the syntax above by the dispatch system, this
redirection over `_evaluate` is necessary
unboxes the partial instructions from DiffPt and applies them to k,
evaluates them at the positions of DiffPt
"""
function _evaluate(k::T, x::DiffPt, y::DiffPt) where {T<:Kernel}
return partial(k, x.partials, y.partials)(x.pos, y.pos)
end
#=
Method-ambiguity workaround.

A `where {T<:Kernel}` clause is not allowed with the call-overload syntax
`(::T)(x, y)`. If we only implemented
```julia
(::Kernel)(::DiffPt, ::DiffPt)
```
then, for a kernel type with its own `(::SpecialKernel)(x, y)` method, julia would
not know whether to use that method or `(::Kernel)(x::DiffPt, y::DiffPt)`.

The loop below therefore defines the `DiffPt` call methods separately for each listed
abstract kernel type, all forwarding to `_evaluate`. `SimpleKernel` is listed
explicitly, presumably because it has its own generic call method elsewhere that would
otherwise be ambiguous. NOTE(review): confirm against the kernel definitions.
=#
for T in [SimpleKernel, Kernel] # possible generalization: loop over subtypes(Kernel)
# Both arguments carry derivative instructions.
(k::T)(x::DiffPt, y::DiffPt) = _evaluate(k, x, y)
# Only `x` is differentiated: wrap the plain position `y` as a zero-order DiffPt.
# DiffPt(y) is only defined for Number and AbstractArray positions.
(k::T)(x::DiffPt, y) = _evaluate(k, x, DiffPt(y))
# Only `y` is differentiated: wrap the plain position `x` as a zero-order DiffPt.
(k::T)(x, y::DiffPt) = _evaluate(k, DiffPt(x), y)
end