Merged
Changes from 128 commits
Commits
136 commits
b2fb557
adding a base class to build density estimators on top of
htjb Nov 26, 2025
fe025a1
tighter prototyping of the functions
htjb Nov 26, 2025
c31ad44
util functions
htjb Nov 26, 2025
aadb5dc
beginning to lay out a kde class as a starting point for v2
htjb Nov 26, 2025
2ffd42b
adding forward and inverse transformations
htjb Nov 27, 2025
c6120a7
realised I can use tensorflow probability with the jax backend
htjb Nov 27, 2025
02e0198
fiddling with the parameter order in the sampling function
htjb Nov 27, 2025
da3b55a
adding in a log prob function with correction
htjb Nov 27, 2025
fc7ea0e
kde reimplementation
htjb Nov 27, 2025
32a8dec
removing old margarine kde class
htjb Nov 27, 2025
8210575
hmmm implementing the conditional inverse transform sampling is trick…
htjb Nov 27, 2025
ce08c15
adding a jax test train split function
htjb Nov 27, 2025
cbc25d0
removing now redundatn processing code
htjb Nov 27, 2025
2596472
rearranging the file structure
htjb Nov 27, 2025
be6ea2a
organising utils into seperate file
htjb Nov 27, 2025
26d7204
an implementation of kmeans in jax
htjb Nov 27, 2025
e885e8f
jax implementation of the silhouette score
htjb Nov 27, 2025
2157315
making sure that the bounds function returns (min, max) and fixing ty…
htjb Nov 27, 2025
28f60e1
adding a generic cluster code to build piecewise estimators from any …
htjb Nov 27, 2025
545645b
better flagging of pycache and ds_store
htjb Nov 27, 2025
7198547
make base class and abstract class with required methods
htjb Nov 28, 2025
ac0b2a2
doesnt need to inherit from the base class
htjb Nov 28, 2025
3911fc7
dont need to init base class here
htjb Nov 28, 2025
752e22f
implementation of the NICE normalising flow
htjb Nov 28, 2025
c790283
bug fix in log prob calculation for kde
htjb Nov 28, 2025
f701de6
depends on flax
htjb Nov 28, 2025
89f1364
better variable name for hidden size
htjb Nov 28, 2025
a62c5f6
updating the log_prob_under_nice function name so it doesnt clash wit…
htjb Nov 28, 2025
916d38a
better reporting of the loss
htjb Nov 28, 2025
32cb9db
realNVP implementation
htjb Nov 28, 2025
e537711
correcting a bug in the bounds estimation
htjb Nov 28, 2025
d82bdbb
calcualte theta ranges seperately for each cluster so that they are w…
htjb Nov 28, 2025
53a9e3c
some doc string improvements
htjb Nov 28, 2025
356cfda
adding a more detailed discription of the origin of the approximate b…
htjb Nov 28, 2025
f44529a
new margarine strap line
htjb Nov 28, 2025
ab5f17e
bumping version number because im excited
htjb Nov 28, 2025
b5b8d6b
fiddling with the pyproject.toml
htjb Nov 28, 2025
c9af1c9
merging
htjb Nov 28, 2025
1475ab6
missing comma
htjb Nov 29, 2025
af61154
need tfp nightly for compatibility with latest jax
htjb Nov 29, 2025
2fe86ba
laying out new statistics code
htjb Nov 29, 2025
ded5d79
code to calcualte kl and bmd
htjb Nov 29, 2025
7a41ff4
jit the log prob functions
htjb Nov 29, 2025
5512bcf
removing old maf class so i can rewrite
htjb Nov 30, 2025
2814a4a
starting to lay out the base of the maf class
htjb Nov 30, 2025
4c16802
adding a to do list briefly so i can keep track
htjb Nov 30, 2025
be26f45
build the mades, masks and inverse and forward passes
htjb Dec 1, 2025
7f20c84
adding in activations after first layers
htjb Dec 1, 2025
547c375
pretty sure this is set up correctly
htjb Dec 1, 2025
7659e96
okay i think this maf is pretty well set up and optimized
htjb Dec 1, 2025
a660b1a
updating the todo
htjb Dec 1, 2025
c5c7b48
modifying the clustered_distribution example
htjb Dec 2, 2025
a6d67be
removing some old code
htjb Dec 2, 2025
39bf3e1
removing old test files
htjb Dec 2, 2025
38d1b51
importing the base from the correct file
htjb Dec 2, 2025
5640fcb
more stable kernel initialisation
htjb Dec 2, 2025
7ef9233
removing ndims in base class
htjb Dec 2, 2025
febb392
adding tests for kde and some of the util functions
htjb Dec 2, 2025
220385e
better name for utils test and formatting on tests
htjb Dec 2, 2025
f713279
tighter constraints on kl and bmd accuracy (I think something wrong w…
htjb Dec 2, 2025
b251d85
testing on a more straight forward problem
htjb Dec 2, 2025
69d6521
i think the latest versions of flax are only tested on 3.11 and above
htjb Dec 2, 2025
bc78dd9
trying to use stochasticity of training to set atol and rtol on kl an…
htjb Dec 2, 2025
a1512c3
playing with the tests. most pass now just trouble with maf training …
htjb Dec 3, 2025
75babff
type in instance check in loglike
htjb Dec 4, 2025
bc9b154
bug fixing the permutations in the maf but still needs some tinkering
htjb Dec 4, 2025
6f526e3
better initialisation of final layer in the mades and some tinkering
htjb Dec 5, 2025
9ce214a
removing maf stuff from v2 branch after splitting into maf branch
htjb Dec 5, 2025
6d24c89
removing maf tests from v2 branch
htjb Dec 5, 2025
c535866
testing the importance sampling functionality
htjb Dec 5, 2025
dafb13f
batched training for nice and realnvp
htjb Dec 5, 2025
8fc3cc7
rough test code for cluster class
htjb Dec 5, 2025
93f1e2c
setting batch size for better performance on toy problems
htjb Dec 9, 2025
9fb9e33
specifying max cluster size
htjb Dec 9, 2025
a009009
seperate theta bounds for each flow causes issues when evaluating log…
htjb Dec 9, 2025
9d3257e
playing with the flow settings but some inf is creeping in somewhere …
htjb Dec 9, 2025
e31a247
hmmm seems the inf is the true kl
htjb Dec 10, 2025
eb7f48c
problem was basically set up wrong so the true kl and bmd was wrong
htjb Dec 10, 2025
f4b5c03
removing the old tutorials notebook
htjb Dec 10, 2025
81edf5a
removing old example multimodal distribution code
htjb Dec 10, 2025
95ab95b
fiddling with training parameters and batch size
htjb Dec 10, 2025
bee8cb6
dealing with conflict with amster branch
htjb Dec 11, 2025
cb0db8d
jax friendly inverse and forward pass and jit of trainingset
htjb Dec 11, 2025
1fbf124
trying to get the cluster tests working
htjb Dec 15, 2025
f01ef84
better type hinting on base estimator
htjb Dec 16, 2025
4939ad8
allow users to set the number of clusters
htjb Dec 16, 2025
a2f9ebd
easier test case for clusters and a larger realnvp network
htjb Dec 16, 2025
408a17f
workign on save and load functions but something isnt quite right
htjb Dec 17, 2025
adff989
return an array from approximate bounds rather than a tuple
htjb Dec 17, 2025
06840d0
a working save and load function for the NICE model
htjb Dec 17, 2025
7b1a21b
surpress some orbax warnings and add save and load tests for nice
htjb Dec 17, 2025
3109b4c
addign a functioning save and load to the realnvp class
htjb Dec 17, 2025
4c56bd7
a test for saving and loading of realnvp
htjb Dec 17, 2025
58275d1
custome extensions on save fiels
htjb Dec 18, 2025
2b21364
more consistent save and load functions for kdes
htjb Dec 18, 2025
92b5088
test save and load of kdes and filepath fixes
htjb Dec 18, 2025
6bc66c4
saving and passign theta ranges
htjb Dec 18, 2025
ebb8c23
working version of save for all estimators and cluster save test
htjb Dec 18, 2025
5a563b1
bug fix in kde save function
htjb Dec 18, 2025
e7d4af0
writing bijective transformation for kde with some help from an llm
htjb Dec 18, 2025
4bf3767
jitting fucntions in __call__ for kde
htjb Dec 18, 2025
f40e8a2
removing the todo file
htjb Jan 5, 2026
6f12071
removign the old docs
htjb Jan 8, 2026
a782000
removing old readme
htjb Jan 8, 2026
8fc95b5
laying out new docs
htjb Jan 8, 2026
bcb0c96
Merge branch 'v2' of github.com:htjb/margarine into v2
htjb Jan 8, 2026
941c02f
optional doc dependencies
htjb Jan 8, 2026
7aa8b9a
set readthedocs up for mkdocs
htjb Jan 8, 2026
008e868
adding an api reference'
htjb Jan 8, 2026
1241e35
fixing the api reference in the docs
htjb Jan 8, 2026
efab237
basic example
htjb Jan 8, 2026
c5d6a15
modifying badge color
htjb Jan 8, 2026
d101290
adding a roadmap and outline of tutorials doc
htjb Jan 8, 2026
83ea332
minor bug fix in doc string
htjb Jan 8, 2026
e77ced1
exapaning the readme intro
htjb Jan 9, 2026
1215305
adding a plan to diversive normalisation techniques
htjb Jan 9, 2026
8a97560
saving the margarine version with the flow and checking against it wh…
htjb Jan 9, 2026
9931fb2
a basic tutorial
htjb Jan 9, 2026
2c84a23
more details in the docs
htjb Jan 9, 2026
c8de4cb
code highlighting
htjb Jan 9, 2026
240952b
correctly reference last v1
htjb Jan 9, 2026
4432d7c
better type hinting on load return
htjb Jan 9, 2026
e454836
more tutorials
htjb Jan 9, 2026
cf08cef
Merge branch 'master' into v2
htjb Jan 9, 2026
497a01f
updating image link in tutorial
htjb Jan 9, 2026
4a9ec33
modifying check version script for md readme
htjb Jan 9, 2026
c4b2d75
adding a note on statistic errors
htjb Jan 9, 2026
8e9dd77
more permissive type hinting on cluster
htjb Jan 9, 2026
0900141
minor changes to readme
htjb Jan 9, 2026
47c87c7
bug fix in log prob calculations
htjb Jan 9, 2026
93d51af
type in a comment
htjb Jan 9, 2026
d1b86b0
better type hinting in train_test_split
htjb Jan 9, 2026
73c4bde
ooopa think anpther bug fix
htjb Jan 9, 2026
bc25cdc
typo in comment
htjb Jan 9, 2026
0cb9d96
spelt calculated wrong again
htjb Jan 9, 2026
7e0d457
removing a redundant duplicate line
htjb Jan 9, 2026
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12", "3.13"]
python-version: ["3.11", "3.12", "3.13"]

steps:
- uses: actions/checkout@v2
4 changes: 2 additions & 2 deletions .gitignore
@@ -1,5 +1,5 @@
*/__pycache__
*/.DS_Store
**/__pycache__
**/.DS_Store
margarine.egg-info/
.pytest_cache/
.ruff_cache/
14 changes: 5 additions & 9 deletions .readthedocs.yml
@@ -7,19 +7,15 @@ build:
tools:
python: "3.11"

# Build documentation in the docs/ directory with Sphinx
sphinx:
builder: html
configuration: docs/source/conf.py
fail_on_warning: false
# Build documentation in the docs/ directory with MkDocs
mkdocs:
configuration: mkdocs.yaml

# Optionally build your docs in additional formats such as PDF and ePub
formats:
- htmlzip

# Optionally set the version of Python and requirements required to build your docs
python:
install:
- method : pip
path : .
- requirements: docs/requirements.txt
extra_requirements :
- docs
142 changes: 142 additions & 0 deletions README.md
@@ -0,0 +1,142 @@
# margarine: you won't believe it's not your posterior samples!

### Marginal Bayesian Statistics

**Authors:** Harry T.J. Bevins
**Version:** 2.0.0
**Homepage:** https://github.com/htjb/margarine
**Documentation:** https://margarine.readthedocs.io/

[![Documentation Status](https://readthedocs.org/projects/margarine/badge/?version=latest)](https://margarine.readthedocs.io/en/latest/?badge=latest) [![arXiv:2205.12841](http://img.shields.io/badge/astro.IM-arXiv%3A2205.12841-DCFF87.svg)](https://arxiv.org/abs/2205.12841)
[![arXiv:2305.02930](http://img.shields.io/badge/astro.IM-arXiv%3A2305.02930-DCFF87.svg)](https://arxiv.org/abs/2305.02930)
[![arXiv:2207.11457](http://img.shields.io/badge/astro.IM-arXiv%3A2207.11457-DCFF87.svg)](https://arxiv.org/abs/2207.11457)
[![PyPI version](https://badge.fury.io/py/margarine.svg)](https://badge.fury.io/py/margarine)
[![Licence: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)


`margarine` provides a suite of density estimation tools, including KDEs and normalizing flows such as NICE and RealNVP, as well as a novel method for improved performance on multimodal distributions.

The code can be used to:

- Emulate posterior distributions from weighted samples (e.g. MCMC, nested sampling)
- Build non-trivial priors from samples
- Perform density estimation tasks in general machine learning applications
- Emulate correctly normalised marginal likelihoods
- Calculate statistics such as the KL divergence between different density estimators and marginal model dimensionalities.
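
As a rough illustration of the KL divergence mentioned in the last item, the divergence between two densities P and Q can be estimated by Monte Carlo from weighted samples of P as `sum(w_i * (log p(x_i) - log q(x_i))) / sum(w_i)`. The sketch below uses only the Python standard library, with two analytic 1D Gaussians standing in for trained density estimators. The names `gaussian_log_prob` and `kl_estimate` are illustrative assumptions and are not part of the `margarine` API.

```python
import math
import random


def gaussian_log_prob(x, mean, std):
    """Log density of a 1D Gaussian, standing in for an estimator's log-probability."""
    z = (x - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * math.log(2.0 * math.pi)


def kl_estimate(samples, weights, log_p, log_q):
    """Weighted Monte Carlo estimate of D_KL(P || Q) using samples drawn from P."""
    total = sum(weights)
    return sum(w * (log_p(x) - log_q(x)) for x, w in zip(samples, weights)) / total


rng = random.Random(0)
samples = [rng.gauss(0.0, 1.0) for _ in range(20000)]
# Equal weights here; MCMC or nested sampling weights would go in their place
weights = [1.0] * len(samples)

estimate = kl_estimate(
    samples,
    weights,
    lambda x: gaussian_log_prob(x, 0.0, 1.0),  # P = N(0, 1)
    lambda x: gaussian_log_prob(x, 1.0, 1.0),  # Q = N(1, 1)
)

# For two unit-variance Gaussians the analytic KL is (mean_p - mean_q)^2 / 2
exact = 0.5
```

The estimate fluctuates around the analytic value with the usual Monte Carlo error, so comparisons between estimators should allow for that scatter.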

---

## Installation

From version 2.0.0, `margarine` uses JAX for improved performance. Older versions (1.x.x) built on TensorFlow remain available via pip; the last such release is 1.4.2.

Install from Git:

```bash
git clone https://github.com/htjb/margarine.git # or use SSH
cd margarine
pip install .
```

Or via pip:

```bash
pip install margarine
```

Note: the version on pip may lag behind the latest version on GitHub.

---

## Getting Started

All of the density estimators in `margarine` share a common interface and set of methods, including `train()`, `sample()`, `log_prob()`, `log_like()`, `save()` and `load()`. The example below shows how to train a RealNVP and generate samples.

```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from margarine.estimators.realnvp import RealNVP

nsamples = 5000
key = jax.random.PRNGKey(0)

# Draw correlated 2D Gaussian samples to act as the target distribution
original_samples = jax.random.multivariate_normal(
    key,
    mean=jnp.array([0.0, 0.0]),
    cov=jnp.array([[1.0, 0.8], [0.8, 1.0]]),
    shape=(nsamples,),
)

# Equal weights; use the sampler's weights for MCMC or nested sampling output
weights = jnp.ones(len(original_samples))

realnvp_estimator = RealNVP(
    original_samples,
    weights=weights,
    in_size=2,
    hidden_size=50,
    num_layers=6,
    num_coupling_layers=6,
)

# Split the key so that training and sampling use different random streams
key, subkey = jax.random.split(key)

realnvp_estimator.train(
    subkey,
    learning_rate=1e-3,
    epochs=2000,
    patience=50,
    batch_size=1000,
)

generated_samples = realnvp_estimator.sample(key, num_samples=nsamples)

# Overlay the original and generated samples for a visual comparison
plt.scatter(
    original_samples[:, 0], original_samples[:, 1], alpha=0.5, label="Original Samples"
)
plt.scatter(
    generated_samples[:, 0], generated_samples[:, 1], alpha=0.5, label="Generated Samples"
)
plt.legend()
plt.title("RealNVP: Original vs Generated Samples")
plt.xlabel("X1")
plt.ylabel("X2")
plt.show()
```

For more details, see the documentation.
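
To show what a shared set of methods can look like in practice, here is a minimal standard-library sketch of a toy 1D Gaussian "estimator" that mirrors the method names listed at the start of this section. `ToyGaussian` is a hypothetical class invented for illustration; its signatures are assumptions and do not reproduce the actual `margarine` classes (for instance, `log_like()` is omitted and the random source is a `random.Random` rather than a JAX key).

```python
import json
import math
import os
import random
import tempfile


class ToyGaussian:
    """Hypothetical 1D estimator; not a margarine class, signatures are assumed."""

    def __init__(self, samples, weights=None):
        self.samples = list(samples)
        self.weights = list(weights) if weights is not None else [1.0] * len(self.samples)
        self.mean, self.std = 0.0, 1.0

    def train(self):
        # "Training" is just a weighted mean and standard deviation
        total = sum(self.weights)
        self.mean = sum(w * x for x, w in zip(self.samples, self.weights)) / total
        var = sum(w * (x - self.mean) ** 2 for x, w in zip(self.samples, self.weights)) / total
        self.std = math.sqrt(var)

    def sample(self, rng, num_samples):
        return [rng.gauss(self.mean, self.std) for _ in range(num_samples)]

    def log_prob(self, x):
        z = (x - self.mean) / self.std
        return -0.5 * z * z - math.log(self.std) - 0.5 * math.log(2.0 * math.pi)

    def save(self, path):
        with open(path, "w") as f:
            json.dump({"mean": self.mean, "std": self.std}, f)

    @classmethod
    def load(cls, path):
        with open(path) as f:
            state = json.load(f)
        model = cls([])
        model.mean, model.std = state["mean"], state["std"]
        return model


rng = random.Random(1)
data = [rng.gauss(2.0, 0.5) for _ in range(5000)]

model = ToyGaussian(data)
model.train()

# Round-trip the fitted parameters through a file and reload them
path = os.path.join(tempfile.mkdtemp(), "toy.json")
model.save(path)
restored = ToyGaussian.load(path)
```

Because every estimator exposes the same methods, downstream code such as plotting or statistics routines can be written once against the interface rather than against a particular model.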

---

## Documentation

Available at: https://margarine.readthedocs.io/. To build locally:

```bash
pip install ".[docs]"
mkdocs serve
```


---

## Licence & Citation

Licensed under MIT.

If used for academic work, please cite:

* Main paper: https://arxiv.org/abs/2205.12841
* MaxEnt22 proceedings: https://arxiv.org/abs/2207.11457
* Piecewise Normalising Flows Paper: https://arxiv.org/abs/2305.02930

---

## Contributing

Contributions and feature suggestions are welcome. Open an issue to report bugs or discuss ideas. See `CONTRIBUTING.md` for details.

The future goals of the project are outlined in `ROADMAP.md`.


161 changes: 0 additions & 161 deletions README.rst

This file was deleted.

18 changes: 18 additions & 0 deletions ROADMAP.md
@@ -0,0 +1,18 @@
# Roadmap

The roadmap outlines major planned additions/changes to the code.

## Current Development

- **Masked Autoregressive Flows**: Currently being implemented in the [maf](https://github.com/htjb/margarine/tree/maf) branch.

## Planned Additions

- **Alternative Normalisation**: Implement standardization as an alternative to the existing gaussianization. Similar to [margarine_unbounded](https://github.com/mrosep/margarine_unbounded).
- **Conditional Flows**: Support for SBI and to build $\beta$-flows (see [here](https://arxiv.org/abs/2411.17663)).
- **Evidence Calculations**: Integration of [floZ](https://arxiv.org/abs/2404.12294) style evidence calculations.
- **Neural Spline Flows**: Implementation of NSF for improved flexibility.
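
To make the difference between the two normalisation schemes named in the first item concrete, the sketch below contrasts them using only the Python standard library. It assumes that "gaussianization" means rescaling samples by known bounds to the unit interval and then applying the standard normal inverse CDF; the transform actually used in `margarine` may differ in detail. The function names `standardize` and `gaussianize` are illustrative, not part of the package.

```python
from statistics import NormalDist, mean, pstdev


def standardize(samples):
    """Shift and scale samples to zero mean and unit standard deviation."""
    mu, sigma = mean(samples), pstdev(samples)
    return [(x - mu) / sigma for x in samples]


def gaussianize(samples, lower, upper):
    """Map samples from (lower, upper) to the unit interval, then through the
    standard normal inverse CDF. Every sample must lie strictly inside the bounds."""
    normal = NormalDist()
    return [normal.inv_cdf((x - lower) / (upper - lower)) for x in samples]


samples = [0.1, 0.25, 0.5, 0.75, 0.9]

standardized = standardize(samples)
gaussianized = gaussianize(samples, lower=0.0, upper=1.0)
```

Standardization needs no bounds, which is why it suits unbounded parameters, whereas the bounded transform above fails for any sample that sits exactly on a bound.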

## Future Exploration

- **Diffusion Models**: Explore addition of diffusion models as alternative density estimators.