
Commit 77fdd82

Merge pull request #61 from cesmix-mit/kernels
Merge kernel functions into main
2 parents f1acd0f + dcf6c9e commit 77fdd82

4 files changed: +287 -7 lines


src/Kernels/distances.jl (+47)

@@ -50,6 +50,11 @@ function Euclidean(
     Euclidean(Cinv, Csqrt)
 end
 
+"""
+    compute_distance(A, B, d)
+
+Compute the distance between features A and B using distance metric d.
+"""
 function compute_distance(B1::Vector{T}, B2::Vector{T}, e::Euclidean) where {T<:Real}
     (B1 - B2)' * e.Cinv * (B1 - B2)
 end
@@ -61,3 +66,45 @@ function compute_distance(
 ) where {T<:Real}
     tr(e.Csqrt * (C1 - C2)' * e.Cinv * (C1 - C2) * e.Csqrt)
 end
+
+"""
+    compute_gradx_distance(A, B, d)
+
+Compute gradient of the distance between features A and B using distance metric d, with respect to the first argument (A).
+"""
+function compute_gradx_distance(
+    A::T,
+    B::T,
+    e::Euclidean
+) where {T<:Vector{<:Real}}
+
+    return 2 * e.Cinv * (A - B)
+end
+
+"""
+    compute_grady_distance(A, B, d)
+
+Compute gradient of the distance between features A and B using distance metric d, with respect to the second argument (B).
+"""
+function compute_grady_distance(
+    A::T,
+    B::T,
+    e::Euclidean
+) where {T<:Vector{<:Real}}
+
+    return -2 * e.Cinv * (A - B)
+end
+
+"""
+    compute_gradxy_distance(A, B, d)
+
+Compute second-order cross derivative of the distance between features A and B using distance metric d.
+"""
+function compute_gradxy_distance(
+    A::T,
+    B::T,
+    e::Euclidean
+) where {T<:Vector{<:Real}}
+
+    return -2 * e.Cinv
+end
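
For reference, the Euclidean-case gradients added above follow directly from the weighted squared distance implemented by compute_distance; a short derivation sketch, writing C^{-1} for e.Cinv:

    d(A, B) = (A - B)^\top C^{-1} (A - B)
    \nabla_A d(A, B) = 2\, C^{-1} (A - B)
    \nabla_B d(A, B) = -2\, C^{-1} (A - B)
    \nabla_A \nabla_B^\top d(A, B) = -2\, C^{-1}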

src/Kernels/divergences.jl (+49)

@@ -0,0 +1,49 @@
+## Discrepancies
+"""
+    Divergence
+
+A struct of abstract type Divergence produces a measure of discrepancy between two probability distributions. Discrepancies may take as argument analytical distributions or sets of samples representing empirical distributions.
+"""
+abstract type Divergence end
+
+
+"""
+    KernelSteinDiscrepancy <: Divergence
+        score :: Function
+        kernel :: Kernel
+
+Computes the kernel Stein discrepancy between distributions p (from which samples are provided) and q (for which the score is provided) based on the RKHS defined by kernel k.
+"""
+struct KernelSteinDiscrepancy <: Divergence
+    score :: Function
+    kernel :: Kernel
+end
+
+function KernelSteinDiscrepancy(; score, kernel)
+    return KernelSteinDiscrepancy(score, kernel)
+end
+
+
+function compute_divergence(
+    x :: Vector{T},
+    div :: KernelSteinDiscrepancy,
+) where T <: Union{Real, Vector{<:Real}}
+
+    N = length(x)
+    sq = div.score.(x)
+    k = div.kernel
+
+    ksd = 0.0
+    for i = 1:N
+        for j = i:N
+            m = (i == j) ? 1 : 2
+            sks = sq[i]' * compute_kernel(x[i], x[j], k) * sq[j]
+            sk = sq[i]' * compute_grady_kernel(x[i], x[j], k)
+            ks = compute_gradx_kernel(x[i], x[j], k)' * sq[j]
+            trk = tr(compute_gradxy_kernel(x[i], x[j], k))
+
+            ksd += m * (sks + sk + ks + trk) / (N*(N-1.0))
+        end
+    end
+    return ksd
+end
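
For reference, the double loop above accumulates the standard kernel Stein discrepancy estimator; a sketch of the quantity being computed, writing s_q for div.score and k for div.kernel:

    \mathrm{KSD}^2(p, q) \approx \frac{1}{N(N-1)} \sum_{i \le j} m_{ij} \Big[ s_q(x_i)^\top k(x_i, x_j)\, s_q(x_j) + s_q(x_i)^\top \nabla_{x_j} k(x_i, x_j) + \nabla_{x_i} k(x_i, x_j)^\top s_q(x_j) + \mathrm{tr}\big( \nabla_{x_i} \nabla_{x_j} k(x_i, x_j) \big) \Big]

where m_{ij} = 1 when i = j and 2 otherwise, so that off-diagonal terms of the symmetric sum are counted twice.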

src/Kernels/kernels.jl (+163 -6)

@@ -1,9 +1,6 @@
 include("features.jl")
 include("distances.jl")
 
-export Distance, Forstner, compute_distance, Euclidean
-export Feature, GlobalMean, CorrelationMatrix, compute_feature, compute_features
-export Kernel, DotProduct, get_parameters, RBF, compute_kernel, KernelMatrix
 ###############
 """
     Kernel
@@ -42,7 +39,7 @@ end
 """
     RBF <: Kernel
        d :: Distance function
-       α :: Reguarlization parameter
+       α :: Regularization parameter
        ℓ :: Length-scale parameter
       β :: Scale parameter
 
@@ -51,7 +48,7 @@ end
 
    k(A, B) = β \exp( -\frac{1}{2} d(A,B)/ℓ^2 ) + α δ(A, B)
 """
-struct RBF <: Kernel
+mutable struct RBF <: Kernel
     d::Distance
     α::Real
     ℓ::Real
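
The gradient kernels added in the next hunk follow from differentiating the RBF expression above; a derivation sketch for the smooth part (the α δ(A, B) regularization term is constant away from A = B and is ignored here):

    k(A, B) = β \exp\big( -\tfrac{1}{2}\, d(A, B)/ℓ^2 \big)
    \nabla_A k(A, B) = -\tfrac{1}{2 ℓ^2}\, k(A, B)\, \nabla_A d(A, B)
    \nabla_B k(A, B) = -\tfrac{1}{2 ℓ^2}\, k(A, B)\, \nabla_B d(A, B)

The second-order cross term follows by differentiating \nabla_A k once more with respect to B.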
@@ -74,7 +71,136 @@ function compute_kernel(
     r::RBF,
 ) where {T<:Union{Vector{<:Real},Symmetric{<:Real,<:Matrix{<:Real}}}}
     d2 = compute_distance(A, B, r.d)
-    r.β * exp(-0.5 * d2 / r.ℓ)
+    r.β * exp(-0.5 * d2 / r.ℓ^2)
+end
+
+"""
+    compute_gradx_kernel(A, B, k)
+
+Compute gradient of the kernel between features A and B using kernel k, with respect to the first argument (A).
+"""
+function compute_gradx_kernel(
+    A::T,
+    B::T,
+    r::RBF,
+) where {T<:Vector{<:Real}}
+
+    k = compute_kernel(A, B, r)
+    ∇d = compute_gradx_distance(A, B, r.d)
+    return -0.5 * k * ∇d / r.ℓ^2
+end
+
+"""
+    compute_grady_kernel(A, B, k)
+
+Compute gradient of the kernel between features A and B using kernel k, with respect to the second argument (B).
+"""
+function compute_grady_kernel(
+    A::T,
+    B::T,
+    r::RBF,
+) where {T<:Vector{<:Real}}
+
+    k = compute_kernel(A, B, r)
+    ∇d = compute_grady_distance(A, B, r.d)
+    return -0.5 * k * ∇d / r.ℓ^2
+end
+
+"""
+    compute_gradxy_kernel(A, B, k)
+
+Compute the second-order cross derivative of the kernel between features A and B using kernel k.
+"""
+function compute_gradxy_kernel(
+    A::T,
+    B::T,
+    r::RBF,
+) where {T<:Vector{<:Real}}
+
+    k = compute_kernel(A, B, r)
+    ∇xd = compute_gradx_distance(A, B, r.d)
+    ∇yd = compute_grady_distance(A, B, r.d)
+    ∇xyd = compute_gradxy_distance(A, B, r.d)
+
+    return k .* ( -0.5 * ∇xyd / r.ℓ^2 .+ 0.25 * ∇xd'*∇yd / r.ℓ^4 )
+
+end
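
A minimal finite-difference sketch of the RBF gradient above, assuming the package exports are in scope and reusing the Euclidean(d) and RBF(e) constructors exercised in the test file further down (d, A, B, h, δ are illustrative names, not part of this PR):

    d = 8
    e = Euclidean(d)    # Euclidean distance metric, as constructed in the tests below
    rbf = RBF(e)        # RBF kernel with default parameters

    A, B = randn(d), randn(d)
    g = compute_gradx_kernel(A, B, rbf)

    # central finite difference along the first coordinate of A
    h = 1e-6
    δ = zeros(d); δ[1] = h
    fd = (compute_kernel(A + δ, B, rbf) - compute_kernel(A - δ, B, rbf)) / (2h)

    @assert isapprox(g[1], fd; rtol = 1e-5, atol = 1e-8)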
+
+"""
+    InverseMultiquadric <: Kernel
+        d :: Distance function
+        c2 :: Squared constant parameter
+        ℓ :: Length-scale parameter
+
+Computes the inverse multiquadric (IMQ) kernel, i.e.,
+
+    k(A, B) = (c^2 + d(A,B)/ℓ^2)^{-1/2}
+"""
+mutable struct InverseMultiquadric <: Kernel
+    d::Distance
+    c2::Real
+    ℓ::Real
+
+    InverseMultiquadric(d, c2, ℓ) = (
+        @assert (0 < c2);
+        @assert (0 < ℓ);
+        new(d, c2, ℓ)
+    )
+end
+# default will be 1.0 for c^2
+InverseMultiquadric(d; c2=1.0, ℓ=1.0) = InverseMultiquadric(d, c2, ℓ)
+
+get_parameters(k::InverseMultiquadric) = (k.c2, k.ℓ)
+
+
+function compute_kernel(
+    A::T,
+    B::T,
+    r::InverseMultiquadric,
+) where {T<:Union{Vector{<:Real},Symmetric{<:Real,<:Matrix{<:Real}}}}
+
+    d2 = compute_distance(A, B, r.d)
+    (r.c2 + d2 / r.ℓ^2)^(-0.5)
+end
+
+function compute_gradx_kernel(
+    A::T,
+    B::T,
+    r::InverseMultiquadric,
+) where {T<:Vector{<:Real}}
+
+    d2 = compute_distance(A, B, r.d)
+    ∇d = compute_gradx_distance(A, B, r.d)
+
+    return -0.5 * ∇d / r.ℓ^2 * (r.c2 + d2 / r.ℓ^2)^(-1.5)
+end
+
+function compute_grady_kernel(
+    A::T,
+    B::T,
+    r::InverseMultiquadric,
+) where {T<:Vector{<:Real}}
+
+    d2 = compute_distance(A, B, r.d)
+    ∇d = compute_grady_distance(A, B, r.d)
+
+    return -0.5 * ∇d / r.ℓ^2 * (r.c2 + d2 / r.ℓ^2)^(-1.5)
+end
+
+
+function compute_gradxy_kernel(
+    A::T,
+    B::T,
+    r::InverseMultiquadric,
+) where {T<:Vector{<:Real}}
+
+    d2 = compute_distance(A, B, r.d)
+    ∇xd = compute_gradx_distance(A, B, r.d)
+    ∇yd = compute_grady_distance(A, B, r.d)
+    ∇xyd = compute_gradxy_distance(A, B, r.d)
+    q = r.c2 + d2 / r.ℓ^2
+
+    return 3*∇xd*∇yd / (4*r.ℓ^4) * q^(-2.5) - ∇xyd / (2*r.ℓ^2) * q^(-1.5)
 end
 
 """
@@ -141,3 +267,34 @@ function KernelMatrix(
     F2 = compute_feature.(ds2, (f,); dt = dt)
     KernelMatrix(F1, F2, k)
 end
+
+
+
+
+include("divergences.jl")
+export
+    Distance,
+    Forstner,
+    compute_distance,
+    compute_gradx_distance,
+    compute_grady_distance,
+    compute_gradxy_distance,
+    Euclidean,
+    Feature,
+    GlobalMean,
+    CorrelationMatrix,
+    compute_feature,
+    compute_features,
+    Kernel,
+    DotProduct,
+    RBF,
+    InverseMultiquadric,
+    get_parameters,
+    compute_kernel,
+    compute_gradx_kernel,
+    compute_grady_kernel,
+    compute_gradxy_kernel,
+    KernelMatrix,
+    Divergence,
+    KernelSteinDiscrepancy,
+    compute_divergence
test/kernels/kernel_tests.jl (+28 -1)

@@ -1,6 +1,8 @@
 using AtomsBase
 using Unitful, UnitfulAtomic
 using LinearAlgebra
+using Distributions
+
 # initialize some fake descriptors
 d = 8
 num_atoms = 20
@@ -30,6 +32,9 @@ f_cm = compute_feature.(ld, (cm,))
 @test compute_features(ds, gm) == f_gm
 @test compute_features(ds, cm) == f_cm
 
+mvn = MvNormal(zeros(d), I(d))
+f_mvn = [rand(mvn) for j = 1:100]
+
 ## distances
 fo = Forstner(1e-16)
 e = Euclidean(d)
@@ -42,12 +47,18 @@ e = Euclidean(d)
 @test compute_distance(f_cm[1], f_cm[1], fo) < eps()
 @test compute_distance(f_cm[1], f_cm[2], fo) > 0.0
 
+@test abs(sum(compute_gradx_distance(f_gm[1], f_gm[1], e))) < eps()
+@test abs(sum(compute_gradx_distance(f_gm[1], f_gm[2], e))) > 0.0
+@test abs(sum(compute_grady_distance(f_gm[1], f_gm[1], e))) < eps()
+@test abs(sum(compute_grady_distance(f_gm[1], f_gm[2], e))) > 0.0
+@test compute_gradxy_distance(f_gm[1], f_gm[1], e) == -2.0*I(d)
+@test compute_gradxy_distance(f_gm[1], f_gm[2], e) == -2.0*I(d)
+
 ## kernels
 dp = DotProduct()
 rbf_e = RBF(e)
 rbf_fo = RBF(fo)
 
-
 @test typeof(dp) <: Kernel
 @test typeof(rbf_e) <: Kernel
 @test typeof(rbf_fo) <: Kernel
@@ -64,8 +75,24 @@ rbf_fo = RBF(fo)
 @test compute_kernel(f_cm[1], f_cm[2], rbf_e) > 0
 @test compute_kernel(f_cm[1], f_cm[2], rbf_fo) > 0
 
+@test abs(sum(compute_gradx_kernel(f_gm[1], f_gm[1], rbf_e))) < eps()
+@test abs(sum(compute_gradx_kernel(f_gm[1], f_gm[2], rbf_e))) > 0
+@test abs(sum(compute_grady_kernel(f_gm[1], f_gm[1], rbf_e))) < eps()
+@test abs(sum(compute_grady_kernel(f_gm[1], f_gm[2], rbf_e))) > 0
+@test compute_gradxy_kernel(f_gm[1], f_gm[1], rbf_e) == I(8)
+@test abs(sum(compute_gradxy_kernel(f_gm[1], f_gm[2], rbf_e))) > 0
+
 @test typeof(KernelMatrix(f_gm, dp)) <: Symmetric{Float64,Matrix{Float64}}
 @test typeof(KernelMatrix(f_cm, dp)) <: Symmetric{Float64,Matrix{Float64}}
 @test typeof(KernelMatrix(f_gm, rbf_e)) <: Symmetric{Float64,Matrix{Float64}}
 @test typeof(KernelMatrix(f_cm, rbf_e)) <: Symmetric{Float64,Matrix{Float64}}
 @test typeof(KernelMatrix(f_cm, rbf_fo)) <: Symmetric{Float64,Matrix{Float64}}
+
+
+# divergences
+std_gauss_score(x) = -x
+ksd = KernelSteinDiscrepancy(score=std_gauss_score, kernel=rbf_e)
+
+@test typeof(ksd) <: Divergence
+@test compute_divergence(f_mvn, ksd) < compute_divergence(f_gm, ksd)
