Skip to content

Add a Multivariate Normal distribution and expand the GP regression example. #1332

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 44 additions & 22 deletions examples/kernelregression.dx
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
'# Kernel Regression

import linalg
import stats
import plot

struct ConjGradState(a|VSpace) =
Expand Down Expand Up @@ -43,7 +44,7 @@ The optimal coefficients are found by solving a linear system $\alpha=G^{-1}y$,\
where $G_{ij}:=k(x_i, x_j)+\delta_{ij}\lambda$, $\lambda>0$ and $y = (y_1,\dots,y_N)^\top\in\mathbb R^N$

-- Synthetic data
Nx = Fin 100
Nx = Fin 20
noise = 0.1
[k1, k2] = split_key (new_key 0)

Expand All @@ -69,9 +70,15 @@ Nxtest = Fin 1000
xtest : Nxtest=>Float = for i. rand (ixkey k1 i)
preds = map predict xtest

-- True function.
:html show_plot $ xy_plot xtest (map trueFun xtest)
> <html output>

-- Observed values.
:html show_plot $ xy_plot xs ys
> <html output>

-- Ridge regression prediction.
:html show_plot $ xy_plot xtest preds
> <html output>

Expand All @@ -83,29 +90,44 @@ with the Bayes rule, gives the variance of the prediction.
' In this implementation, the conjugate gradient solver is replaced with the
Cholesky solver from `lib/linalg.dx` for efficiency.

def gp_regress(
kernel: (a, a) -> Float,
xs: n=>a,
ys: n=>Float
) -> ((a) -> (Float, Float)) given (n|Ix, a) =
noise_var = 0.0001
def gp(
kernel: (a, a) -> Float,
xs: n=>a,
mean_fn: (a) -> Float,
noise_var: Float
) -> MultivariateNormal n given (n|Ix, a) =
gram = for i j. kernel xs[i] xs[j]
c = chol (gram + eye *. noise_var)
alpha = chol_solve c ys
predict = \x.
k' = for i. kernel xs[i] x
mu = sum for i. alpha[i] * k'[i]
alpha' = chol_solve c k'
var = kernel x x + noise_var - sum for i. k'[i] * alpha'[i]
(mu, var)
predict
loc = for i. mean_fn xs[i]
chol_cov = chol (gram + eye *. noise_var)
MultivariateNormal loc chol_cov

gp_predict = gp_regress (\x y. rbf 0.2 x y) xs ys

(gp_preds, vars) = unzip (map gp_predict xtest)

:html show_plot $ xyc_plot xtest gp_preds (map sqrt vars)
def gp_regress(
kernel: (a, a) -> Float,
xs: n=>a,
ys: n=>Float,
mean_fn: (a) -> Float,
noise_var: Float
) -> ((m=>a) -> MultivariateNormal m) given (n|Ix, m|Ix, a) =
prior_gp = gp kernel xs mean_fn noise_var
gram_obs_inv_y = chol_solve prior_gp.chol_cov (ys - prior_gp.loc)
predictive_gp_fn = \xs_pred:m=>a.
gram_pred_obs = for i j. kernel xs_pred[i] xs[j]
loc = gram_pred_obs **. gram_obs_inv_y + (for i. mean_fn xs_pred[i])
gram_pred = (for i j. kernel xs_pred[i] xs_pred[j]) + eye *. noise_var
gram_obs_inv_gram_pred_obs = for i. chol_solve prior_gp.chol_cov gram_pred_obs[i]
schur = gram_pred - gram_obs_inv_gram_pred_obs ** (transpose gram_pred_obs)
MultivariateNormal loc (chol schur)
predictive_gp_fn

def mean_fn(x:Float) -> Float = 0. * x
gp_predict_fn : (Nxtest=>Float) -> MultivariateNormal Nxtest = gp_regress (\x y. rbf 0.2 x y) xs ys mean_fn 0.0001
gp_predict_dist = gp_predict_fn xtest
var_pred = for i. vdot gp_predict_dist.chol_cov[i] gp_predict_dist.chol_cov[i]

-- GP posterior predictive mean, colored by variance.
:html show_plot $ xyc_plot xtest gp_predict_dist.loc (map sqrt var_pred)
> <html output>

:html show_plot $ xy_plot xtest vars
-- Posterior predictive variance.
:html show_plot $ xy_plot xtest var_pred
> <html output>
22 changes: 22 additions & 0 deletions lib/stats.dx
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
'# Stats
Probability distributions and other functions useful for statistical computing.

import linalg

'## Log-space floating point numbers
When working with probability densities, mass functions, distributions,
likelihoods, etc., we often work on a logarithmic scale to prevent floating
Expand Down Expand Up @@ -333,6 +335,26 @@ instance OrderedDist(Uniform, Float, Float)
def quantile(d, q) = d.low + ((d.high - d.low) * q)


'## Multivariate probability distributions
Some commonly encountered multivariate distributions.
### Multivariate Normal distribution
The [Multivariate Normal distribution](https://en.wikipedia.org/wiki/Multivariate_normal_distribution) is parameterised by its *mean*, `loc`, and the Cholesky factor of its covariance, `chol_cov`.

struct MultivariateNormal(n|Ix) =
loc : (n=>Float)
chol_cov : LowerTriMat n Float

instance Random(MultivariateNormal(n), n=>Float) given (n|Ix)
def draw(d, k) =
std_norm = for i:n. randn (ixkey k i)
d.loc + for i:n. sum(for j:(..i). d.chol_cov[i, j] * std_norm[inject j])

instance Dist(MultivariateNormal(n), n=>Float, Float) given (n|Ix)
def density(d, x) =
y = forward_substitute d.chol_cov (x - d.loc)
dim = n_to_f (size n)
Exp (-(dim / 2) * log (2 * pi) - sum(log (lower_tri_diag d.chol_cov)) - 0.5 * dot y y)

'## Data summaries
Some data summary functions. Note that `mean` is provided by the prelude.

Expand Down
28 changes: 28 additions & 0 deletions tests/stats-tests.dx
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,34 @@ quantile (Uniform 2.0 5.0) 0.2 ~~ 2.6
rand_vec 5 (\k. draw (Uniform 2.0 5.0) k) (new_key 0) :: Fin 5=>Float
> [4.610805, 2.740888, 2.510233, 3.040717, 3.731907]

-- multivariate normal

draw (Uniform 2.0 5.0) (new_key 0) :: Float

chol_cov_mat : Fin 2=>Fin 2=>Float = [[0.2, 0.], [-0.3, 0.1]]

> Compiler bug!
> > Please report this at github.com/google-research/dex-lang/issues
> >
> > Unexpected table: chol_co.1
> > CallStack (from HasCallStack):
> > error, called at src/lib/Simplify.hs:571:22 in dex-0.1.0.
-- chol_cov = for i:(Fin 2). for j:(..i). chol_cov_mat[i, inject(to=Fin 2, j)]
-- chol_cov = for i:(Fin 2). for j:(..i). chol_cov_mat[i, (ordinal j)@_]
-- chol_cov = for i:(Fin 2). for j:(..i). chol_cov_mat[i, inject j]

-- I think this used to work.
> > Type error:
> > Expected: (RangeTo (Fin 2) 0)
> > Actual: (Fin 1)
chol_cov : (i:Fin 2)=>(..i)=>Float = [[0.2], [-0.3, 0.1]]
loc : (Fin 2=>Float) = [1., 2.]
draw (MultivariateNormal loc chol_cov) (new_key 0) :: (Fin 2=>Float)
> [0.706645, 2.599938]

ln (density (MultivariateNormal [1., 1] chol_cov) [0.5, 0.5]) ~~ -79.1758
> True


-- data summaries

Expand Down