You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
A struct of abstract type Divergence produces a measure of discrepancy between two probability distributions. Discepancies may take as argument analytical distributions or sets of samples representing empirical distributions.
6
+
"""
7
+
abstract type Divergence end
8
+
9
+
10
+
"""
11
+
KernelSteinDiscrepancy <: Divergence
12
+
score :: Function
13
+
knl :: Kernel
14
+
15
+
Computes the kernel Stein discrepancy between distributions p (from which samples are provided) and q (for which the score is provided) based on the RKHS defined by kernel k.
16
+
"""
17
+
struct KernelSteinDiscrepancy <:Divergence
18
+
score ::Function
19
+
kernel ::Kernel
20
+
end
21
+
22
+
functionKernelSteinDiscrepancy(; score, kernel)
23
+
returnKernelSteinDiscrepancy(score, kernel)
24
+
end
25
+
26
+
27
+
functioncompute_divergence(
28
+
x ::Vector{T},
29
+
div ::KernelSteinDiscrepancy,
30
+
) where T <:Union{Real, Vector{<:Real}}
31
+
32
+
N =length(x)
33
+
sq = div.score.(x)
34
+
k = div.kernel
35
+
36
+
ksd =0.0
37
+
for i =1:N
38
+
for j = i:N
39
+
m = (i == j) ?1:2
40
+
sks = sq[i]'*compute_kernel(x[i], x[j], k) * sq[j]
0 commit comments