From 404d959ee366bca1b5dc26eb66ad76454d902ac0 Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Thu, 8 Dec 2022 14:15:31 -0800 Subject: [PATCH 1/7] Add a Multivariate Normal distribution. --- lib/stats.dx | 25 +++++++++++++++++++++++-- tests/stats-tests.dx | 12 ++++++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/lib/stats.dx b/lib/stats.dx index e4e688ccb..938cbdb52 100644 --- a/lib/stats.dx +++ b/lib/stats.dx @@ -1,6 +1,8 @@ '# Stats Probability distributions and other functions useful for statistical computing. +import linalg + '## Log-space floating point numbers When working with probability densities, mass functions, distributions, likelihoods, etc., we often work on a logarithmic scale to prevent floating @@ -329,9 +331,28 @@ instance OrderedDist(Uniform, Float, Float) then Exp 0.0 else if (x > d.high) then Exp (-infinity) - else Exp $ log (d.high - x) - log (d.high - d.low) - def quantile(d, q) = d.low + ((d.high - d.low) * q) + else Exp $ log (high - x) - log (high - low) + quantile = \(Uniform low high) q. + low + ((high - low) * q) + + +'## Multivariate probability distributions +Some commonly encountered multivariate distributions. +### Multivariate Normal distribution +The [Multivariate Normal distribution](https://en.wikipedia.org/wiki/Multivariate_normal_distribution) is parameterised by its *mean*, `loc`, and Cholesky-factored `scale`. + +data MultivariateNormalDist n [Ix n] = MultivariateNormal (n=>Float) (LowerTriMat n Float) + +instance Random (MultivariateNormalDist n) (n=>Float) given {n} [Ix n] + draw = \(MultivariateNormal loc chol_cov) k. + std_norm = for i:n. randn (ixkey k i) + loc + for i:n. sum for j:(..i). chol_cov.i.j * std_norm.(inject _ j) +instance Dist (MultivariateNormalDist n) (n=>Float) Float given {n} [Ix n] + density = \(MultivariateNormal loc chol_cov) x. + y = forward_substitute chol_cov (x - loc) + dim = n_to_f (size n) + Exp (-(dim / 2) * log (2 * pi) - sum (log (lower_tri_diag chol_cov)) - 0.5 * dot y y) '## Data summaries Some data summary functions. Note that `mean` is provided by the prelude. diff --git a/tests/stats-tests.dx b/tests/stats-tests.dx index 3f29c1999..cb01c71fb 100644 --- a/tests/stats-tests.dx +++ b/tests/stats-tests.dx @@ -245,6 +245,18 @@ quantile (Uniform 2.0 5.0) 0.2 ~~ 2.6 rand_vec 5 (\k. draw (Uniform 2.0 5.0) k) (new_key 0) :: Fin 5=>Float > [4.610805, 2.740888, 2.510233, 3.040717, 3.731907] +-- multivariate normal + +draw (Uniform 2.0 5.0) (new_key 0) :: Float + +chol_cov : ((i:Fin 2)=>(..i)=>Float) = [[0.2], [-0.3, 0.1]] +loc : (Fin 2=>Float) = [1., 2.] +draw (MultivariateNormal loc chol_cov) (new_key 0) :: (Fin 2=>Float) +> [0.706645, 2.599938] + +ln (density (MultivariateNormal [1., 1] chol_cov) [0.5, 0.5]) ~~ -79.1758 +> True + -- data summaries From a3fa31e1c60e8a5a62db23681f97a90d795c8c11 Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Fri, 9 Dec 2022 09:36:17 -0800 Subject: [PATCH 2/7] Add expanded GP regression example. --- examples/kernelregression.dx | 54 ++++++++++++++++++++++++++++++++---- 1 file changed, 49 insertions(+), 5 deletions(-) diff --git a/examples/kernelregression.dx b/examples/kernelregression.dx index 0135bd1f8..8788d0f43 100644 --- a/examples/kernelregression.dx +++ b/examples/kernelregression.dx @@ -1,6 +1,7 @@ '# Kernel Regression import linalg +import stats import plot struct ConjGradState(a|VSpace) = @@ -83,11 +84,41 @@ with the Bayes rule, gives the variance of the prediction. ' In this implementation, the conjugate gradient solver is replaced with the cholesky solver from `lib/linalg.dx` for efficiency. -def gp_regress( - kernel: (a, a) -> Float, - xs: n=>a, - ys: n=>Float - ) -> ((a) -> (Float, Float)) given (n|Ix, a) = +def gp {n a} + (kernel: a -> a -> Float) + (xs: n=>a) + (mean_fn: a -> Float) + (noise_var: Float) + : MultivariateNormalDist n = + gram = for i j. kernel xs.i xs.j + loc = for i. mean_fn xs.i + chol_cov = chol (gram + eye *. noise_var) + MultivariateNormal loc chol_cov + +def gp_regress_matrix {n m a} + (kernel: a -> a -> Float) + (xs: n=>a) + (mean_fn: a -> Float) + (noise_var: Float) + (ys: n=>Float) + : (m=>a -> MultivariateNormalDist m) = + prior_gp = gp kernel xs mean_fn noise_var + (MultivariateNormal loc_obs chol_obs) = prior_gp + k_inv_y = chol_solve chol_obs (ys - for i. mean_fn xs.i) -- n + predictive_gp_fn = \xs_pred:m=>a. + k_star = for i j. kernel xs.i xs_pred.j -- n by m + loc = for i. sum for k. k_star.k.i * k_inv_y.k + mean_fn xs_pred.i -- m + gram_pred = for i j. kernel xs_pred.i xs_pred.j + s = for i. chol_solve chol_obs (transpose k_star).i -- m by n + schur = gram_pred - for i j. sum for k. s.i.k * s.j.k + MultivariateNormal loc (chol schur) + predictive_gp_fn + +def gp_regress {n a} + (kernel: a -> a -> Float) + (xs: n=>a) + (ys: n=>Float) + : (a -> (Float&Float)) = noise_var = 0.0001 gram = for i j. kernel xs[i] xs[j] c = chol (gram + eye *. noise_var) @@ -109,3 +140,16 @@ gp_predict = gp_regress (\x y. rbf 0.2 x y) xs ys :html show_plot $ xy_plot xtest vars > + +def mean_fn (x:Float) : Float = 0. * x + +N = Fin 10 +noise_std = 0.001 +[k3, k4, k5] = split_key (new_key 3) +xs_obs : N=>Float = for i. rand (ixkey k3 i) +ys_obs : N=>Float = for i. trueFun xs_obs.i + noise_std * randn (ixkey k4 i) +M = Fin 40 +xs_pred : M=>Float = for i. rand (ixkey k5 i) +gprm_predict : M=>Float -> MultivariateNormalDist M = gp_regress_matrix (rbf 0.2) xs mean_fn 0.001 ys +pred_dist = gprm_predict xs_pred + From 02e8b6688439fa628a23cd9a459b53d322c80bea Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Fri, 9 Dec 2022 11:56:25 -0800 Subject: [PATCH 3/7] Update GPRM example. --- examples/kernelregression.dx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/kernelregression.dx b/examples/kernelregression.dx index 8788d0f43..863b9f572 100644 --- a/examples/kernelregression.dx +++ b/examples/kernelregression.dx @@ -150,6 +150,6 @@ xs_obs : N=>Float = for i. rand (ixkey k3 i) ys_obs : N=>Float = for i. trueFun xs_obs.i + noise_std * randn (ixkey k4 i) M = Fin 40 xs_pred : M=>Float = for i. rand (ixkey k5 i) -gprm_predict : M=>Float -> MultivariateNormalDist M = gp_regress_matrix (rbf 0.2) xs mean_fn 0.001 ys +gprm_predict : M=>Float -> MultivariateNormalDist M = gp_regress_matrix (rbf 0.2) xs_obs mean_fn 0.001 ys_obs pred_dist = gprm_predict xs_pred From 8a402f4fe90d7b297511ef3f4ad957aa4d76e6dc Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Sun, 14 May 2023 21:12:04 -0700 Subject: [PATCH 4/7] Update for new syntax. --- examples/kernelregression.dx | 60 ++++++++++++++++++------------------ lib/stats.dx | 23 +++++++------- 2 files changed, 42 insertions(+), 41 deletions(-) diff --git a/examples/kernelregression.dx b/examples/kernelregression.dx index 863b9f572..147ecb52a 100644 --- a/examples/kernelregression.dx +++ b/examples/kernelregression.dx @@ -84,41 +84,41 @@ with the Bayes rule, gives the variance of the prediction. ' In this implementation, the conjugate gradient solver is replaced with the cholesky solver from `lib/linalg.dx` for efficiency. -def gp {n a} - (kernel: a -> a -> Float) - (xs: n=>a) - (mean_fn: a -> Float) - (noise_var: Float) - : MultivariateNormalDist n = - gram = for i j. kernel xs.i xs.j - loc = for i. mean_fn xs.i +def gp( + kernel: (a, a) -> Float, + xs: n=>a, + mean_fn: (a) -> Float, + noise_var: Float +) -> MultivariateNormal n given (n|Ix, a) = + gram = for i j. kernel xs[i] xs[j] + loc = for i. mean_fn xs[i] chol_cov = chol (gram + eye *. noise_var) MultivariateNormal loc chol_cov -def gp_regress_matrix {n m a} - (kernel: a -> a -> Float) - (xs: n=>a) - (mean_fn: a -> Float) - (noise_var: Float) - (ys: n=>Float) - : (m=>a -> MultivariateNormalDist m) = +def gp_regress_matrix( + kernel: (a, a) -> Float, + xs: n=>a, + mean_fn: (a) -> Float, + noise_var: Float, + ys: n=> Float +) -> ((m=>a) -> MultivariateNormal m) given (n|Ix, m|Ix, a) = prior_gp = gp kernel xs mean_fn noise_var - (MultivariateNormal loc_obs chol_obs) = prior_gp - k_inv_y = chol_solve chol_obs (ys - for i. mean_fn xs.i) -- n + -- How is prior_gp.loc incorporated? + k_inv_y = chol_solve prior_gp.chol_cov (ys - for i. mean_fn xs[i]) -- n predictive_gp_fn = \xs_pred:m=>a. - k_star = for i j. kernel xs.i xs_pred.j -- n by m - loc = for i. sum for k. k_star.k.i * k_inv_y.k + mean_fn xs_pred.i -- m - gram_pred = for i j. kernel xs_pred.i xs_pred.j - s = for i. chol_solve chol_obs (transpose k_star).i -- m by n - schur = gram_pred - for i j. sum for k. s.i.k * s.j.k + k_star = for i j. kernel xs[i] xs_pred[j] -- n by m + loc = for i. sum for k. k_star[k, i] * k_inv_y[k] + mean_fn xs_pred[i] -- m + gram_pred = for i j. kernel xs_pred[i] xs_pred[j] + s = for i. chol_solve prior_gp.chol_cov (transpose k_star)[i] -- m by n + schur = gram_pred - for i j. sum for k. s[i, k] * s[j, k] MultivariateNormal loc (chol schur) predictive_gp_fn -def gp_regress {n a} - (kernel: a -> a -> Float) - (xs: n=>a) - (ys: n=>Float) - : (a -> (Float&Float)) = +def gp_regress( + kernel: (a, a) -> Float, + xs: n=>a, + ys: n=>Float +) -> ((a) -> (Float, Float)) given (n|Ix, a) = noise_var = 0.0001 gram = for i j. kernel xs[i] xs[j] c = chol (gram + eye *. noise_var) @@ -141,15 +141,15 @@ gp_predict = gp_regress (\x y. rbf 0.2 x y) xs ys :html show_plot $ xy_plot xtest vars > -def mean_fn (x:Float) : Float = 0. * x +def mean_fn(x:Float) -> Float = 0. * x N = Fin 10 noise_std = 0.001 [k3, k4, k5] = split_key (new_key 3) xs_obs : N=>Float = for i. rand (ixkey k3 i) -ys_obs : N=>Float = for i. trueFun xs_obs.i + noise_std * randn (ixkey k4 i) +ys_obs : N=>Float = for i. trueFun xs_obs[i] + noise_std * randn (ixkey k4 i) M = Fin 40 xs_pred : M=>Float = for i. rand (ixkey k5 i) -gprm_predict : M=>Float -> MultivariateNormalDist M = gp_regress_matrix (rbf 0.2) xs_obs mean_fn 0.001 ys_obs +gprm_predict : (M=>Float) -> MultivariateNormal M = gp_regress_matrix (\x y. rbf 0.2 x y) xs_obs mean_fn 0.001 ys_obs pred_dist = gprm_predict xs_pred diff --git a/lib/stats.dx b/lib/stats.dx index 938cbdb52..00ea01ea0 100644 --- a/lib/stats.dx +++ b/lib/stats.dx @@ -331,9 +331,8 @@ instance OrderedDist(Uniform, Float, Float) then Exp 0.0 else if (x > d.high) then Exp (-infinity) - else Exp $ log (high - x) - log (high - low) - quantile = \(Uniform low high) q. - low + ((high - low) * q) + else Exp $ log (d.high - x) - log (d.high - d.low) + def quantile(d, q) = d.low + ((d.high - d.low) * q) '## Multivariate probability distributions @@ -341,18 +340,20 @@ Some commonly encountered multivariate distributions. ### Multivariate Normal distribution The [Multivariate Normal distribution](https://en.wikipedia.org/wiki/Multivariate_normal_distribution) is parameterised by its *mean*, `loc`, and Cholesky-factored `scale`. -data MultivariateNormalDist n [Ix n] = MultivariateNormal (n=>Float) (LowerTriMat n Float) +struct MultivariateNormal(n|Ix) = + loc : (n=>Float) + chol_cov : LowerTriMat n Float -instance Random (MultivariateNormalDist n) (n=>Float) given {n} [Ix n] - draw = \(MultivariateNormal loc chol_cov) k. +instance Random(MultivariateNormal(n), n=>Float) given (n|Ix) + def draw(d, k) = std_norm = for i:n. randn (ixkey k i) - loc + for i:n. sum for j:(..i). chol_cov.i.j * std_norm.(inject _ j) + d.loc + for i:n. sum(for j:(..i). d.chol_cov[i, j] * std_norm[inject j]) -instance Dist (MultivariateNormalDist n) (n=>Float) Float given {n} [Ix n] - density = \(MultivariateNormal loc chol_cov) x. - y = forward_substitute chol_cov (x - loc) +instance Dist(MultivariateNormal(n), n=>Float, Float) given (n|Ix) + def density(d, x) = + y = forward_substitute d.chol_cov (x - d.loc) dim = n_to_f (size n) - Exp (-(dim / 2) * log (2 * pi) - sum (log (lower_tri_diag chol_cov)) - 0.5 * dot y y) + Exp (-(dim / 2) * log (2 * pi) - sum(log (lower_tri_diag d.chol_cov)) - 0.5 * dot y y) '## Data summaries Some data summary functions. Note that `mean` is provided by the prelude. From 4266ad173f1109d3ea4ea867785f66e098b0cc05 Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Sun, 18 Jun 2023 21:13:42 -0700 Subject: [PATCH 5/7] Cleaning up GP regression example. --- examples/kernelregression.dx | 57 +++++++++++++----------------------- 1 file changed, 21 insertions(+), 36 deletions(-) diff --git a/examples/kernelregression.dx b/examples/kernelregression.dx index 147ecb52a..596ec956e 100644 --- a/examples/kernelregression.dx +++ b/examples/kernelregression.dx @@ -95,61 +95,46 @@ def gp( chol_cov = chol (gram + eye *. noise_var) MultivariateNormal loc chol_cov -def gp_regress_matrix( +def gp_regress( kernel: (a, a) -> Float, xs: n=>a, + ys: n=>Float, mean_fn: (a) -> Float, - noise_var: Float, - ys: n=> Float + noise_var: Float ) -> ((m=>a) -> MultivariateNormal m) given (n|Ix, m|Ix, a) = - prior_gp = gp kernel xs mean_fn noise_var - -- How is prior_gp.loc incorporated? - k_inv_y = chol_solve prior_gp.chol_cov (ys - for i. mean_fn xs[i]) -- n + prior_gram = for i j. kernel xs[i] xs[j] + prior_chol_cov = chol (prior_gram + eye *. noise_var) + k_inv_y = chol_solve prior_chol_cov (ys - for i. mean_fn xs[i]) predictive_gp_fn = \xs_pred:m=>a. - k_star = for i j. kernel xs[i] xs_pred[j] -- n by m - loc = for i. sum for k. k_star[k, i] * k_inv_y[k] + mean_fn xs_pred[i] -- m + k_star = for i j. kernel xs[i] xs_pred[j] + loc = for i. sum for k. k_star[k, i] * k_inv_y[k] + mean_fn xs_pred[i] gram_pred = for i j. kernel xs_pred[i] xs_pred[j] - s = for i. chol_solve prior_gp.chol_cov (transpose k_star)[i] -- m by n + s = for i. chol_solve prior_chol_cov (transpose k_star)[i] schur = gram_pred - for i j. sum for k. s[i, k] * s[j, k] - MultivariateNormal loc (chol schur) + MultivariateNormal loc (chol (schur + eye *. noise_var)) predictive_gp_fn -def gp_regress( - kernel: (a, a) -> Float, - xs: n=>a, - ys: n=>Float -) -> ((a) -> (Float, Float)) given (n|Ix, a) = - noise_var = 0.0001 - gram = for i j. kernel xs[i] xs[j] - c = chol (gram + eye *. noise_var) - alpha = chol_solve c ys - predict = \x. - k' = for i. kernel xs[i] x - mu = sum for i. alpha[i] * k'[i] - alpha' = chol_solve c k' - var = kernel x x + noise_var - sum for i. k'[i] * alpha'[i] - (mu, var) - predict - -gp_predict = gp_regress (\x y. rbf 0.2 x y) xs ys - -(gp_preds, vars) = unzip (map gp_predict xtest) +def mean_fn(x:Float) -> Float = 0. * x +gp_predict_fn : (Nxtest=>Float) -> MultivariateNormal Nxtest = gp_regress (\x y. rbf 0.2 x y) xs ys mean_fn 0.5 +gp_predict_test = gp_predict_fn xtest +stddev_pred = lower_tri_diag gp_predict_test.chol_cov -:html show_plot $ xyc_plot xtest gp_preds (map sqrt vars) +:html show_plot $ xyc_plot xtest gp_predict_test.loc stddev_pred > -:html show_plot $ xy_plot xtest vars +:html show_plot $ xy_plot xtest stddev_pred > -def mean_fn(x:Float) -> Float = 0. * x N = Fin 10 noise_std = 0.001 [k3, k4, k5] = split_key (new_key 3) -xs_obs : N=>Float = for i. rand (ixkey k3 i) +xs_obs : N=>Float = for i. rand (ixkey k3 i) ys_obs : N=>Float = for i. trueFun xs_obs[i] + noise_std * randn (ixkey k4 i) M = Fin 40 xs_pred : M=>Float = for i. rand (ixkey k5 i) -gprm_predict : (M=>Float) -> MultivariateNormal M = gp_regress_matrix (\x y. rbf 0.2 x y) xs_obs mean_fn 0.001 ys_obs +gprm_predict : (M=>Float) -> MultivariateNormal M = gp_regress (\x y. rbf 0.2 x y) xs_obs ys_obs mean_fn 0.001 pred_dist = gprm_predict xs_pred - +stddev_pred +gp_predict_test.loc +xtest From 8d8990268433bde43a1e120878b7bf639030d8e2 Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Wed, 19 Jul 2023 22:11:46 -0700 Subject: [PATCH 6/7] Cleanup and bug fixes to kernelregression example. --- examples/kernelregression.dx | 39 +++++++++++------------------------- 1 file changed, 12 insertions(+), 27 deletions(-) diff --git a/examples/kernelregression.dx b/examples/kernelregression.dx index 596ec956e..9c6b80749 100644 --- a/examples/kernelregression.dx +++ b/examples/kernelregression.dx @@ -102,39 +102,24 @@ def gp_regress( mean_fn: (a) -> Float, noise_var: Float ) -> ((m=>a) -> MultivariateNormal m) given (n|Ix, m|Ix, a) = - prior_gram = for i j. kernel xs[i] xs[j] - prior_chol_cov = chol (prior_gram + eye *. noise_var) - k_inv_y = chol_solve prior_chol_cov (ys - for i. mean_fn xs[i]) + prior_gp = gp kernel xs mean_fn noise_var + gram_obs_inv_y = chol_solve prior_gp.chol_cov (ys - prior_gp.loc) predictive_gp_fn = \xs_pred:m=>a. - k_star = for i j. kernel xs[i] xs_pred[j] - loc = for i. sum for k. k_star[k, i] * k_inv_y[k] + mean_fn xs_pred[i] - gram_pred = for i j. kernel xs_pred[i] xs_pred[j] - s = for i. chol_solve prior_chol_cov (transpose k_star)[i] - schur = gram_pred - for i j. sum for k. s[i, k] * s[j, k] - MultivariateNormal loc (chol (schur + eye *. noise_var)) + gram_pred_obs = for i j. kernel xs_pred[i] xs[j] + loc = gram_pred_obs **. gram_obs_inv_y + (for i. mean_fn xs_pred[i]) + gram_pred = (for i j. kernel xs_pred[i] xs_pred[j]) + eye *. noise_var + gram_obs_inv_gram_pred_obs = for i. chol_solve prior_gp.chol_cov gram_pred_obs[i] + schur = gram_pred - gram_obs_inv_gram_pred_obs ** (transpose gram_pred_obs) + MultivariateNormal loc (chol schur) predictive_gp_fn def mean_fn(x:Float) -> Float = 0. * x -gp_predict_fn : (Nxtest=>Float) -> MultivariateNormal Nxtest = gp_regress (\x y. rbf 0.2 x y) xs ys mean_fn 0.5 +gp_predict_fn : (Nxtest=>Float) -> MultivariateNormal Nxtest = gp_regress (\x y. rbf 0.2 x y) xs ys mean_fn 0.0001 gp_predict_test = gp_predict_fn xtest -stddev_pred = lower_tri_diag gp_predict_test.chol_cov +var_pred = for i. vdot gp_predict_test.chol_cov[i] gp_predict_test.chol_cov[i] -:html show_plot $ xyc_plot xtest gp_predict_test.loc stddev_pred +:html show_plot $ xyc_plot xtest gp_predict_test.loc (map sqrt var_pred) > -:html show_plot $ xy_plot xtest stddev_pred +:html show_plot $ xy_plot xtest var_pred > - - -N = Fin 10 -noise_std = 0.001 -[k3, k4, k5] = split_key (new_key 3) -xs_obs : N=>Float = for i. rand (ixkey k3 i) -ys_obs : N=>Float = for i. trueFun xs_obs[i] + noise_std * randn (ixkey k4 i) -M = Fin 40 -xs_pred : M=>Float = for i. rand (ixkey k5 i) -gprm_predict : (M=>Float) -> MultivariateNormal M = gp_regress (\x y. rbf 0.2 x y) xs_obs ys_obs mean_fn 0.001 -pred_dist = gprm_predict xs_pred -stddev_pred -gp_predict_test.loc -xtest From c3c7eb150d598cddaa09ebf6b8b4435f372f9656 Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Mon, 24 Jul 2023 22:19:26 -0700 Subject: [PATCH 7/7] Work/cleanup on MVN in stats.dx and kernelregression example. --- examples/kernelregression.dx | 16 ++++++++++++---- tests/stats-tests.dx | 18 +++++++++++++++++- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/examples/kernelregression.dx b/examples/kernelregression.dx index 9c6b80749..760149aa0 100644 --- a/examples/kernelregression.dx +++ b/examples/kernelregression.dx @@ -44,7 +44,7 @@ The optimal coefficients are found by solving a linear system $\alpha=G^{-1}y$,\ where $G_{ij}:=k(x_i, x_j)+\delta_{ij}\lambda$, $\lambda>0$ and $y = (y_1,\dots,y_N)^\top\in\mathbb R^N$ -- Synthetic data -Nx = Fin 100 +Nx = Fin 20 noise = 0.1 [k1, k2] = split_key (new_key 0) @@ -70,9 +70,15 @@ Nxtest = Fin 1000 xtest : Nxtest=>Float = for i. rand (ixkey k1 i) preds = map predict xtest +-- True function. +:html show_plot $ xy_plot xtest (map trueFun xtest) +> + +-- Observed values. :html show_plot $ xy_plot xs ys > +-- Ridge regression prediction. :html show_plot $ xy_plot xtest preds > @@ -115,11 +121,13 @@ def gp_regress( def mean_fn(x:Float) -> Float = 0. * x gp_predict_fn : (Nxtest=>Float) -> MultivariateNormal Nxtest = gp_regress (\x y. rbf 0.2 x y) xs ys mean_fn 0.0001 -gp_predict_test = gp_predict_fn xtest -var_pred = for i. vdot gp_predict_test.chol_cov[i] gp_predict_test.chol_cov[i] +gp_predict_dist = gp_predict_fn xtest +var_pred = for i. vdot gp_predict_dist.chol_cov[i] gp_predict_dist.chol_cov[i] -:html show_plot $ xyc_plot xtest gp_predict_test.loc (map sqrt var_pred) +-- GP posterior predictive mean, colored by variance. +:html show_plot $ xyc_plot xtest gp_predict_dist.loc (map sqrt var_pred) > +-- Posterior predictive variance. :html show_plot $ xy_plot xtest var_pred > diff --git a/tests/stats-tests.dx b/tests/stats-tests.dx index cb01c71fb..cd35a2868 100644 --- a/tests/stats-tests.dx +++ b/tests/stats-tests.dx @@ -249,7 +249,23 @@ rand_vec 5 (\k. draw (Uniform 2.0 5.0) k) (new_key 0) :: Fin 5=>Float draw (Uniform 2.0 5.0) (new_key 0) :: Float -chol_cov : ((i:Fin 2)=>(..i)=>Float) = [[0.2], [-0.3, 0.1]] +chol_cov_mat : Fin 2=>Fin 2=>Float = [[0.2, 0.], [-0.3, 0.1]] + +> Compiler bug! +> > Please report this at github.com/google-research/dex-lang/i +> > +> > Unexpected table: chol_co.1 +> > CallStack (from HasCallStack): +> > error, called at src/lib/Simplify.hs:571:22 in dex-0.1.0. +-- chol_cov = for i:(Fin 2). for j:(..i). chol_cov_mat[i, inject(to=Fin 2, j)] +-- chol_cov = for i:(Fin 2). for j:(..i). chol_cov_mat[i, (ordinal j)@_] +-- chol_cov = for i:(Fin 2). for j:(..i). chol_cov_mat[i, inject j] + +-- I think this used to work. +> > Type error: +> > Expected: (RangeTo (Fin 2) 0) +> > Actual: (Fin 1) +chol_cov : (i:Fin 2)=>(..i)=>Float = [[0.2], [-0.3, 0.1]] loc : (Fin 2=>Float) = [1., 2.] draw (MultivariateNormal loc chol_cov) (new_key 0) :: (Fin 2=>Float) > [0.706645, 2.599938]