From 949b94eaa36a1171a7c1a3e2192d1f07a9b73548 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Mon, 30 Sep 2024 14:58:17 +0100 Subject: [PATCH] Update README example to SymNum + update autodiff support description --- README.md | 157 ++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 99 insertions(+), 58 deletions(-) diff --git a/README.md b/README.md index a876515..7c4eefe 100644 --- a/README.md +++ b/README.md @@ -1,23 +1,23 @@ -
+

Mici logo -

- - PyPI version - - - DOI - - - Test status - - - Documentation status - -

+

+
+ + PyPI version + + + DOI + + + Test status + + + Documentation status +
**Mici** is a Python package providing implementations of *Markov chain Monte @@ -34,6 +34,10 @@ Key features include extend the package, * a pure Python code base with minimal dependencies, allowing easy integration within other code, + * built-in support for several automatic differentiation frameworks, including + [JAX](https://jax.readthedocs.io/en/latest/) and + [Autograd](https://github.com/HIPS/autograd), or the option to supply your own + derivative functions, * implementations of MCMC methods for sampling from distributions on embedded manifolds implicitly-defined by a constraint equation and distributions on Riemannian manifolds with a user-specified metric, @@ -63,6 +67,14 @@ pip install git+https://github.com/matt-graham/mici If available in the installed Python environment the following additional packages provide extra functionality and features + * [ArviZ](https://python.arviz.org/en/latest/index.html): if ArviZ is + available the traces (dictionary) output of a sampling run can be directly + converted to an `arviz.InferenceData` container object using + `arviz.convert_to_inference_data` or implicitly converted by passing the + traces dictionary as the `data` argument + [to ArviZ API functions](https://python.arviz.org/en/latest/api/index.html), + allowing straightforward use of the ArviZ's extensive visualisation and + diagnostic functions. * [Autograd](https://github.com/HIPS/autograd): if available Autograd will be used to automatically compute the required derivatives of the model functions (providing they are specified using functions from the @@ -74,15 +86,22 @@ packages provide extra functionality and features serialisation (via [dill](https://github.com/uqfoundation/dill)) of a much wider range of types, including of Autograd generated functions. Both Autograd and multiprocess can be installed alongside Mici by running `pip - install mici[autodiff]`. - * [ArviZ](https://python.arviz.org/en/latest/index.html): if ArviZ is - available the traces (dictionary) output of a sampling run can be directly - converted to an `arviz.InferenceData` container object using - `arviz.convert_to_inference_data` or implicitly converted by passing the - traces dictionary as the `data` argument - [to ArviZ API functions](https://python.arviz.org/en/latest/api/index.html), - allowing straightforward use of the ArviZ's extensive visualisation and - diagnostic functions. + install mici[autograd]`. + * [JAX](https://jax.readthedocs.io/en/latest/): if available JAX will be used to + automatically compute the required derivatives of the model functions (providing + they are specified using functions from the [`jax` + interface](https://jax.readthedocs.io/en/latest/jax.html)). To sample chains + parallel using JAX functions you also need to install + [multiprocess](https://github.com/uqfoundation/multiprocess), though note due to + JAX's use of multithreading which [is incompatible with forking child + processes](https://docs.python.org/3/library/os.html#os.fork), this can result in + deadlock. Both JAX and multiprocess can be installed alongside Mici by running `pip + install mici[jax]`. + * [SymNum](https://github.com/matt-graham/symnum): if available SymNum will be used to + automatically compute the required derivatives of the model functions (providing + they are specified using functions from the [`symnum.numpy` + interface](https://matt-graham.github.io/symnum/symnum.numpy.html)). Symnum can be + installed alongside Mici by running `pip install mici[symnum]`. ## Why Mici? @@ -122,7 +141,7 @@ chains in Python can dominate the computational cost, making sampling much slower than packages which outsource the sampling loop to a efficient compiled implementation. - ## Overview of package +## Overview of package API documentation for the package is available [here](https://matt-graham.github.io/mici/). The three main user-facing @@ -257,22 +276,21 @@ The manifold MCMC methods implemented in Mici have been used in several research -A simple complete example of using the package to compute approximate samples -from a distribution on a two-dimensional torus embedded in a three-dimensional -space is given below. The computed samples are visualized in the animation -above. Here we use `autograd` to automatically construct functions to calculate -the required derivatives (gradient of negative log density of target -distribution and Jacobian of constraint function), sample four chains in -parallel using `multiprocess`, use `arviz` to calculate diagnostics and use -`matplotlib` to plot the samples. - -> ⚠️ **If you do not have [`multiprocess`](https://github.com/uqfoundation/multiprocess) installed the example code below will hang or raise an error when sampling the chains as the inbuilt `multiprocessing` module does not support pickling Autograd functions.** +A simple complete example of using the package to compute approximate samples from a +distribution on a two-dimensional torus embedded in a three-dimensional space is given +below. The computed samples are visualized in the animation above. Here we use +[SymNum](https://github.com/matt-graham/symnum) to automatically construct functions to +calculate the required derivatives (gradient of negative log density of target +distribution and Jacobian of constraint function), sample four chains in parallel using +`multiprocessing`, use [ArviZ](https://python.arviz.org/en/stable/) to calculate +diagnostics and use [Matplotlib](https://matplotlib.org/) to plot the samples. ```Python -from mici import systems, integrators, samplers -import autograd.numpy as np +import mici +import numpy as np +import symnum +import symnum.numpy as snp import matplotlib.pyplot as plt -from mpl_toolkits.mplot3d import Axes3D import matplotlib.animation as animation import arviz @@ -281,44 +299,62 @@ R = 1.0 # toroidal radius ∈ (0, ∞) r = 0.5 # poloidal radius ∈ (0, R) α = 0.9 # density fluctuation amplitude ∈ [0, 1) +# State dimension +dim_q = 3 + + # Define constraint function such that the set {q : constr(q) == 0} is a torus +@symnum.numpify(dim_q) def constr(q): - x, y, z = q.T - return np.stack([((x**2 + y**2)**0.5 - R)**2 + z**2 - r**2], -1) + x, y, z = q + return snp.array([((x**2 + y**2) ** 0.5 - R) ** 2 + z**2 - r**2]) + # Define negative log density for the target distribution on torus # (with respect to 2D 'area' measure for torus) +@symnum.numpify(dim_q) def neg_log_dens(q): - x, y, z = q.T - θ = np.arctan2(y, x) - ϕ = np.arctan2(z, x / np.cos(θ) - R) - return np.log1p(r * np.cos(ϕ) / R) - np.log1p(np.sin(4*θ) * np.cos(ϕ) * α) + x, y, z = q + θ = snp.arctan2(y, x) + ϕ = snp.arctan2(z, x / snp.cos(θ) - R) + return snp.log1p(r * snp.cos(ϕ) / R) - snp.log1p(snp.sin(4 * θ) * snp.cos(ϕ) * α) + # Specify constrained Hamiltonian system with default identity metric -system = systems.DenseConstrainedEuclideanMetricSystem(neg_log_dens, constr) +system = mici.systems.DenseConstrainedEuclideanMetricSystem( + neg_log_dens, + constr, + backend="symnum", +) # System is constrained therefore use constrained leapfrog integrator -integrator = integrators.ConstrainedLeapfrogIntegrator(system) +integrator = mici.integrators.ConstrainedLeapfrogIntegrator(system) # Seed a random number generator rng = np.random.default_rng(seed=1234) # Use dynamic integration-time HMC implementation as MCMC sampler -sampler = samplers.DynamicMultinomialHMC(system, integrator, rng) +sampler = mici.samplers.DynamicMultinomialHMC(system, integrator, rng) # Sample initial positions on torus using parameterisation (θ, ϕ) ∈ [0, 2π)² # x, y, z = (R + r * cos(ϕ)) * cos(θ), (R + r * cos(ϕ)) * sin(θ), r * sin(ϕ) n_chain = 4 θ_init, ϕ_init = rng.uniform(0, 2 * np.pi, size=(2, n_chain)) -q_init = np.stack([ - (R + r * np.cos(ϕ_init)) * np.cos(θ_init), - (R + r * np.cos(ϕ_init)) * np.sin(θ_init), - r * np.sin(ϕ_init)], -1) +q_init = np.stack( + [ + (R + r * np.cos(ϕ_init)) * np.cos(θ_init), + (R + r * np.cos(ϕ_init)) * np.sin(θ_init), + r * np.sin(ϕ_init), + ], + -1, +) + # Define function to extract variables to trace during sampling def trace_func(state): x, y, z = state.pos - return {'x': x, 'y': y, 'z': z} + return {"x": x, "y": y, "z": z} + # Sample 4 chains in parallel with 500 adaptive warm up iterations in which the # integrator step size is tuned, followed by 2000 non-adaptive iterations @@ -327,7 +363,7 @@ final_states, traces, stats = sampler.sample_chains( n_main_iter=2000, init_states=q_init, n_process=4, - trace_funcs=[trace_func] + trace_funcs=[trace_func], ) # Print average accept probability and number of integrator steps per chain @@ -340,19 +376,24 @@ for c in range(n_chain): print(arviz.summary(traces)) # Visualize concatentated chain samples as animated 3D scatter plot -fig = plt.figure(figsize=(4, 4)) -ax = Axes3D(fig, [0., 0., 1., 1.], proj_type='ortho') -points_3d, = ax.plot(*(np.concatenate(traces[k]) for k in 'xyz'), '.', ms=0.5) -ax.axis('off') +fig, ax = plt.subplots( + figsize=(4, 4), + subplot_kw={"projection": "3d", "proj_type": "ortho"}, +) +(points_3d,) = ax.plot(*(np.concatenate(traces[k]) for k in "xyz"), ".", ms=0.5) +ax.axis("off") for set_lim in [ax.set_xlim, ax.set_ylim, ax.set_zlim]: set_lim((-1, 1)) + def update(i): angle = 45 * (np.sin(2 * np.pi * i / 60) + 1) ax.view_init(elev=angle, azim=angle) return (points_3d,) -anim = animation.FuncAnimation(fig, update, frames=60, interval=100, blit=True) + +anim = animation.FuncAnimation(fig, update, frames=60, interval=100) +plt.show() ``` ## References