Skip to content

Commit

Permalink
Update README example to SymNum + update autodiff support description
Browse files Browse the repository at this point in the history
  • Loading branch information
matt-graham committed Sep 30, 2024
1 parent 2fb5a4c commit 949b94e
Showing 1 changed file with 99 additions and 58 deletions.
157 changes: 99 additions & 58 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
<div style="text-align: center;" align="center">
<h1 style="text-align: center;" align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/matt-graham/mici/main/images/mici-logo-rectangular-light-text.svg">
<source media="(prefers-color-scheme: light)" srcset="https://raw.githubusercontent.com/matt-graham/mici/main/images/mici-logo-rectangular.svg">
<img alt="Mici logo" src="https://raw.githubusercontent.com/matt-graham/mici/main/images/mici-logo-rectangular.svg" width="400px">
</picture>
<p>
<a href="https://badge.fury.io/py/mici">
<img src="https://badge.fury.io/py/mici.svg" alt="PyPI version"/>
</a>
<a href="https://zenodo.org/badge/latestdoi/52494384">
<img src="https://zenodo.org/badge/52494384.svg" alt="DOI"/>
</a>
<a href="https://github.com/matt-graham/mici/actions/workflows/tests.yml">
<img src="https://github.com/matt-graham/mici/actions/workflows/tests.yml/badge.svg" alt="Test status" />
</a>
<a href="https://matt-graham.github.io/mici">
<img src="https://github.com/matt-graham/mici/actions/workflows/docs.yml/badge.svg" alt="Documentation status" />
</a>
</p>
</h1>
<div style="text-align: center;" align="center">
<a href="https://badge.fury.io/py/mici">
<img src="https://badge.fury.io/py/mici.svg" alt="PyPI version"/>
</a>
<a href="https://zenodo.org/badge/latestdoi/52494384">
<img src="https://zenodo.org/badge/52494384.svg" alt="DOI"/>
</a>
<a href="https://github.com/matt-graham/mici/actions/workflows/tests.yml">
<img src="https://github.com/matt-graham/mici/actions/workflows/tests.yml/badge.svg" alt="Test status" />
</a>
<a href="https://matt-graham.github.io/mici">
<img src="https://github.com/matt-graham/mici/actions/workflows/docs.yml/badge.svg" alt="Documentation status" />
</a>
</div>

**Mici** is a Python package providing implementations of *Markov chain Monte
Expand All @@ -34,6 +34,10 @@ Key features include
extend the package,
* a pure Python code base with minimal dependencies,
allowing easy integration within other code,
* built-in support for several automatic differentiation frameworks, including
[JAX](https://jax.readthedocs.io/en/latest/) and
[Autograd](https://github.com/HIPS/autograd), or the option to supply your own
derivative functions,
* implementations of MCMC methods for sampling from distributions on embedded
manifolds implicitly-defined by a constraint equation and distributions on
Riemannian manifolds with a user-specified metric,
Expand Down Expand Up @@ -63,6 +67,14 @@ pip install git+https://github.com/matt-graham/mici
If available in the installed Python environment the following additional
packages provide extra functionality and features

* [ArviZ](https://python.arviz.org/en/latest/index.html): if ArviZ is
available the traces (dictionary) output of a sampling run can be directly
converted to an `arviz.InferenceData` container object using
`arviz.convert_to_inference_data` or implicitly converted by passing the
traces dictionary as the `data` argument
[to ArviZ API functions](https://python.arviz.org/en/latest/api/index.html),
allowing straightforward use of the ArviZ's extensive visualisation and
diagnostic functions.
* [Autograd](https://github.com/HIPS/autograd): if available Autograd will
be used to automatically compute the required derivatives of the model
functions (providing they are specified using functions from the
Expand All @@ -74,15 +86,22 @@ packages provide extra functionality and features
serialisation (via [dill](https://github.com/uqfoundation/dill)) of a much
wider range of types, including of Autograd generated functions. Both
Autograd and multiprocess can be installed alongside Mici by running `pip
install mici[autodiff]`.
* [ArviZ](https://python.arviz.org/en/latest/index.html): if ArviZ is
available the traces (dictionary) output of a sampling run can be directly
converted to an `arviz.InferenceData` container object using
`arviz.convert_to_inference_data` or implicitly converted by passing the
traces dictionary as the `data` argument
[to ArviZ API functions](https://python.arviz.org/en/latest/api/index.html),
allowing straightforward use of the ArviZ's extensive visualisation and
diagnostic functions.
install mici[autograd]`.
* [JAX](https://jax.readthedocs.io/en/latest/): if available JAX will be used to
automatically compute the required derivatives of the model functions (providing
they are specified using functions from the [`jax`
interface](https://jax.readthedocs.io/en/latest/jax.html)). To sample chains
parallel using JAX functions you also need to install
[multiprocess](https://github.com/uqfoundation/multiprocess), though note due to
JAX's use of multithreading which [is incompatible with forking child
processes](https://docs.python.org/3/library/os.html#os.fork), this can result in
deadlock. Both JAX and multiprocess can be installed alongside Mici by running `pip
install mici[jax]`.
* [SymNum](https://github.com/matt-graham/symnum): if available SymNum will be used to
automatically compute the required derivatives of the model functions (providing
they are specified using functions from the [`symnum.numpy`
interface](https://matt-graham.github.io/symnum/symnum.numpy.html)). Symnum can be
installed alongside Mici by running `pip install mici[symnum]`.

## Why Mici?

Expand Down Expand Up @@ -122,7 +141,7 @@ chains in Python can dominate the computational cost, making sampling much
slower than packages which outsource the sampling loop to a efficient compiled
implementation.

## Overview of package
## Overview of package

API documentation for the package is available
[here](https://matt-graham.github.io/mici/). The three main user-facing
Expand Down Expand Up @@ -257,22 +276,21 @@ The manifold MCMC methods implemented in Mici have been used in several research

<img src='https://raw.githubusercontent.com/matt-graham/mici/main/images/torus-samples.gif' width='360px'/>

A simple complete example of using the package to compute approximate samples
from a distribution on a two-dimensional torus embedded in a three-dimensional
space is given below. The computed samples are visualized in the animation
above. Here we use `autograd` to automatically construct functions to calculate
the required derivatives (gradient of negative log density of target
distribution and Jacobian of constraint function), sample four chains in
parallel using `multiprocess`, use `arviz` to calculate diagnostics and use
`matplotlib` to plot the samples.

> ⚠️ **If you do not have [`multiprocess`](https://github.com/uqfoundation/multiprocess) installed the example code below will hang or raise an error when sampling the chains as the inbuilt `multiprocessing` module does not support pickling Autograd functions.**
A simple complete example of using the package to compute approximate samples from a
distribution on a two-dimensional torus embedded in a three-dimensional space is given
below. The computed samples are visualized in the animation above. Here we use
[SymNum](https://github.com/matt-graham/symnum) to automatically construct functions to
calculate the required derivatives (gradient of negative log density of target
distribution and Jacobian of constraint function), sample four chains in parallel using
`multiprocessing`, use [ArviZ](https://python.arviz.org/en/stable/) to calculate
diagnostics and use [Matplotlib](https://matplotlib.org/) to plot the samples.

```Python
from mici import systems, integrators, samplers
import autograd.numpy as np
import mici
import numpy as np
import symnum
import symnum.numpy as snp
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as animation
import arviz

Expand All @@ -281,44 +299,62 @@ R = 1.0 # toroidal radius ∈ (0, ∞)
r = 0.5 # poloidal radius ∈ (0, R)
α = 0.9 # density fluctuation amplitude ∈ [0, 1)

# State dimension
dim_q = 3


# Define constraint function such that the set {q : constr(q) == 0} is a torus
@symnum.numpify(dim_q)
def constr(q):
x, y, z = q.T
return np.stack([((x**2 + y**2)**0.5 - R)**2 + z**2 - r**2], -1)
x, y, z = q
return snp.array([((x**2 + y**2) ** 0.5 - R) ** 2 + z**2 - r**2])


# Define negative log density for the target distribution on torus
# (with respect to 2D 'area' measure for torus)
@symnum.numpify(dim_q)
def neg_log_dens(q):
x, y, z = q.T
θ = np.arctan2(y, x)
ϕ = np.arctan2(z, x / np.cos(θ) - R)
return np.log1p(r * np.cos(ϕ) / R) - np.log1p(np.sin(4*θ) * np.cos(ϕ) * α)
x, y, z = q
θ = snp.arctan2(y, x)
ϕ = snp.arctan2(z, x / snp.cos(θ) - R)
return snp.log1p(r * snp.cos(ϕ) / R) - snp.log1p(snp.sin(4 * θ) * snp.cos(ϕ) * α)


# Specify constrained Hamiltonian system with default identity metric
system = systems.DenseConstrainedEuclideanMetricSystem(neg_log_dens, constr)
system = mici.systems.DenseConstrainedEuclideanMetricSystem(
neg_log_dens,
constr,
backend="symnum",
)

# System is constrained therefore use constrained leapfrog integrator
integrator = integrators.ConstrainedLeapfrogIntegrator(system)
integrator = mici.integrators.ConstrainedLeapfrogIntegrator(system)

# Seed a random number generator
rng = np.random.default_rng(seed=1234)

# Use dynamic integration-time HMC implementation as MCMC sampler
sampler = samplers.DynamicMultinomialHMC(system, integrator, rng)
sampler = mici.samplers.DynamicMultinomialHMC(system, integrator, rng)

# Sample initial positions on torus using parameterisation (θ, ϕ) ∈ [0, 2π)²
# x, y, z = (R + r * cos(ϕ)) * cos(θ), (R + r * cos(ϕ)) * sin(θ), r * sin(ϕ)
n_chain = 4
θ_init, ϕ_init = rng.uniform(0, 2 * np.pi, size=(2, n_chain))
q_init = np.stack([
(R + r * np.cos(ϕ_init)) * np.cos(θ_init),
(R + r * np.cos(ϕ_init)) * np.sin(θ_init),
r * np.sin(ϕ_init)], -1)
q_init = np.stack(
[
(R + r * np.cos(ϕ_init)) * np.cos(θ_init),
(R + r * np.cos(ϕ_init)) * np.sin(θ_init),
r * np.sin(ϕ_init),
],
-1,
)


# Define function to extract variables to trace during sampling
def trace_func(state):
x, y, z = state.pos
return {'x': x, 'y': y, 'z': z}
return {"x": x, "y": y, "z": z}


# Sample 4 chains in parallel with 500 adaptive warm up iterations in which the
# integrator step size is tuned, followed by 2000 non-adaptive iterations
Expand All @@ -327,7 +363,7 @@ final_states, traces, stats = sampler.sample_chains(
n_main_iter=2000,
init_states=q_init,
n_process=4,
trace_funcs=[trace_func]
trace_funcs=[trace_func],
)

# Print average accept probability and number of integrator steps per chain
Expand All @@ -340,19 +376,24 @@ for c in range(n_chain):
print(arviz.summary(traces))

# Visualize concatentated chain samples as animated 3D scatter plot
fig = plt.figure(figsize=(4, 4))
ax = Axes3D(fig, [0., 0., 1., 1.], proj_type='ortho')
points_3d, = ax.plot(*(np.concatenate(traces[k]) for k in 'xyz'), '.', ms=0.5)
ax.axis('off')
fig, ax = plt.subplots(
figsize=(4, 4),
subplot_kw={"projection": "3d", "proj_type": "ortho"},
)
(points_3d,) = ax.plot(*(np.concatenate(traces[k]) for k in "xyz"), ".", ms=0.5)
ax.axis("off")
for set_lim in [ax.set_xlim, ax.set_ylim, ax.set_zlim]:
set_lim((-1, 1))


def update(i):
angle = 45 * (np.sin(2 * np.pi * i / 60) + 1)
ax.view_init(elev=angle, azim=angle)
return (points_3d,)

anim = animation.FuncAnimation(fig, update, frames=60, interval=100, blit=True)

anim = animation.FuncAnimation(fig, update, frames=60, interval=100)
plt.show()
```

## References
Expand Down

0 comments on commit 949b94e

Please sign in to comment.