Skip to content

Commit c62d8f8

Browse files
committed
add divergence functions and tests
Functions which compute the divergence between two probability distributions. This commit includes the kernel Stein discrepancy as a divergence metric and associated tests.
1 parent 2c0f1e1 commit c62d8f8

File tree

3 files changed

+108
-38
lines changed

3 files changed

+108
-38
lines changed

Diff for: src/Kernels/divergences.jl

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
## Discrepancies

"""
    Divergence

A struct of abstract type Divergence produces a measure of discrepancy between
two probability distributions. Discrepancies may take as argument analytical
distributions or sets of samples representing empirical distributions.
"""
abstract type Divergence end
8+
9+
10+
"""
11+
KernelSteinDiscrepancy <: Divergence
12+
score :: Function
13+
knl :: Kernel
14+
15+
Computes the kernel Stein discrepancy between distributions p (from which samples are provided) and q (for which the score is provided) based on the RKHS defined by kernel k.
16+
"""
17+
struct KernelSteinDiscrepancy <: Divergence
18+
score :: Function
19+
kernel :: Kernel
20+
end
21+
22+
function KernelSteinDiscrepancy(; score, kernel)
23+
return KernelSteinDiscrepancy(score, kernel)
24+
end
25+
26+
27+
"""
    compute_divergence(x, div::KernelSteinDiscrepancy)

Estimate the kernel Stein discrepancy between the distribution p underlying
the samples `x` and the distribution q whose score function is stored in
`div.score`, using the kernel `div.kernel`.

# Arguments
- `x`: samples from p — a vector of reals or a vector of real vectors.
- `div`: the `KernelSteinDiscrepancy` holding q's score function and the kernel.

# Returns
A scalar `Float64` discrepancy estimate.

# Throws
- `ArgumentError` if fewer than two samples are provided (the estimator
  normalizes by N*(N-1), which is zero for N = 1).
"""
function compute_divergence(
    x :: Vector{T},
    div :: KernelSteinDiscrepancy,
) where T <: Union{Real, Vector{<:Real}}

    N = length(x)
    # Guard the degenerate case: the original silently returned Inf/NaN here.
    N >= 2 || throw(ArgumentError("kernel Stein discrepancy requires at least 2 samples, got $N"))

    sq = div.score.(x)   # score evaluated at every sample
    k = div.kernel

    # Accumulate the symmetric Stein-kernel double sum; each off-diagonal
    # pair (i < j) is counted twice via the multiplier m.
    ksd = 0.0
    for i = 1:N
        for j = i:N
            m = (i == j) ? 1 : 2
            sks = sq[i]' * compute_kernel(x[i], x[j], k) * sq[j]
            sk = sq[i]' * compute_grady_kernel(x[i], x[j], k)
            ks = compute_gradx_kernel(x[i], x[j], k)' * sq[j]
            trk = tr(compute_gradxy_kernel(x[i], x[j], k))
            # NOTE(review): diagonal (i == j) terms are included while the
            # normalization is N*(N-1) — confirm whether a U-statistic
            # (excluding i == j) was intended.
            ksd += m * (sks + sk + ks + trk) / (N*(N-1.0))
        end
    end
    return ksd
end

Diff for: src/Kernels/kernels.jl

+45-38
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,6 @@
11
include("features.jl")
22
include("distances.jl")
33

4-
export
5-
Distance,
6-
Forstner,
7-
compute_distance,
8-
compute_gradx_distance,
9-
compute_grady_distance,
10-
compute_gradxy_distance,
11-
Euclidean,
12-
Feature,
13-
GlobalMean,
14-
CorrelationMatrix,
15-
compute_feature,
16-
compute_features,
17-
Kernel,
18-
DotProduct,
19-
get_parameters,
20-
RBF,
21-
compute_kernel,
22-
compute_gradx_kernel,
23-
compute_grady_kernel,
24-
compute_gradxy_kernel,
25-
KernelMatrix
26-
274
###############
285
"""
296
Kernel
@@ -94,7 +71,7 @@ function compute_kernel(
9471
r::RBF,
9572
) where {T<:Union{Vector{<:Real},Symmetric{<:Real,<:Matrix{<:Real}}}}
9673
d2 = compute_distance(A, B, r.d)
97-
r.β * exp(-0.5 * d2 / r.^2)
74+
r.β * exp(-0.5 * d2 / (2*r.^2))
9875
end
9976

10077
"""
@@ -105,12 +82,12 @@ Compute gradient of the kernel between features A and B using kernel k, with res
10582
# Gradient of the RBF kernel value k(A, B) with respect to the first
# argument A, obtained by the chain rule through the underlying distance.
function compute_gradx_kernel(
    A::T,
    B::T,
    r::RBF,
) where {T<:Vector{<:Real}}
    kval = compute_kernel(A, B, r)
    ddx = compute_gradx_distance(A, B, r.d)
    # NOTE(review): `r.^2` looks like a scrape-garbled field access
    # (likely a length-scale, e.g. `r.ℓ^2`) — confirm against the RBF struct.
    return -kval * ddx / (2*r.^2)
end
11592

11693
"""
@@ -121,12 +98,12 @@ Compute gradient of the kernel between features A and B using kernel k, with res
12198
# Gradient of the RBF kernel value k(A, B) with respect to the second
# argument B, obtained by the chain rule through the underlying distance.
function compute_grady_kernel(
    A::T,
    B::T,
    r::RBF,
) where {T<:Vector{<:Real}}
    kval = compute_kernel(A, B, r)
    ddy = compute_grady_distance(A, B, r.d)
    # NOTE(review): `r.^2` looks like a scrape-garbled field access
    # (likely a length-scale, e.g. `r.ℓ^2`) — confirm against the RBF struct.
    return -kval * ddy / (2*r.^2)
end
131108

132109
"""
@@ -137,15 +114,15 @@ Compute the second-order cross derivative of the kernel between features A and B
137114
# Second-order mixed derivative ∂²k/∂A∂B of the RBF kernel, assembled from
# the first- and second-order derivatives of the underlying distance.
function compute_gradxy_kernel(
    A::T,
    B::T,
    r::RBF,
) where {T<:Vector{<:Real}}
    kval = compute_kernel(A, B, r)
    dx = compute_gradx_distance(A, B, r.d)
    dy = compute_grady_distance(A, B, r.d)
    dxy = compute_gradxy_distance(A, B, r.d)
    # NOTE(review): `r.^2` / `r.^4` appear to be garbled field accesses
    # (likely a length-scale, e.g. `r.ℓ`) — confirm against the RBF struct.
    return kval .* ( -dxy/(2*r.^2) .+ dx'*dy/(4*r.^4) )
end
151128

@@ -213,3 +190,33 @@ function KernelMatrix(
213190
F2 = compute_feature.(ds2, (f,); dt = dt)
214191
KernelMatrix(F1, F2, k)
215192
end
193+
194+
195+
196+
197+
include("divergences.jl")
198+
export
199+
Distance,
200+
Forstner,
201+
compute_distance,
202+
compute_gradx_distance,
203+
compute_grady_distance,
204+
compute_gradxy_distance,
205+
Euclidean,
206+
Feature,
207+
GlobalMean,
208+
CorrelationMatrix,
209+
compute_feature,
210+
compute_features,
211+
Kernel,
212+
DotProduct,
213+
get_parameters,
214+
RBF,
215+
compute_kernel,
216+
compute_gradx_kernel,
217+
compute_grady_kernel,
218+
compute_gradxy_kernel,
219+
KernelMatrix,
220+
Divergence,
221+
KernelSteinDiscrepancy,
222+
compute_divergence

Diff for: test/kernels/kernel_tests.jl

+14
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using AtomsBase
22
using Unitful, UnitfulAtomic
33
using LinearAlgebra
4+
using Distributions
5+
46
# initialize some fake descriptors
57
d = 8
68
num_atoms = 20
@@ -30,6 +32,9 @@ f_cm = compute_feature.(ld, (cm,))
3032
@test compute_features(ds, gm) == f_gm
3133
@test compute_features(ds, cm) == f_cm
3234

35+
mvn = MvNormal(zeros(d), I(d))
36+
f_mvn = [rand(mvn) for j = 1:100]
37+
3338
## distances
3439
fo = Forstner(1e-16)
3540
e = Euclidean(d)
@@ -82,3 +87,12 @@ rbf_fo = RBF(fo)
8287
@test typeof(KernelMatrix(f_gm, rbf_e)) <: Symmetric{Float64,Matrix{Float64}}
8388
@test typeof(KernelMatrix(f_cm, rbf_e)) <: Symmetric{Float64,Matrix{Float64}}
8489
@test typeof(KernelMatrix(f_cm, rbf_fo)) <: Symmetric{Float64,Matrix{Float64}}
90+
91+
92+
# divergences
std_gauss_score(x) = -x   # score ∇x log q(x) for q = N(0, I)
ksd = KernelSteinDiscrepancy(score=std_gauss_score, kernel=rbf_e)

@test ksd isa Divergence
# f_mvn are draws from N(0, I), matching the score, so their discrepancy
# should be smaller than that of the unrelated descriptor features f_gm.
# NOTE(review): this comparison is stochastic — seed the RNG before drawing
# f_mvn to keep the test deterministic.
@test compute_divergence(f_mvn, ksd) < compute_divergence(f_gm, ksd)
98+

0 commit comments

Comments
 (0)