Version 2 #68
Merged
Commits
136 commits
b2fb557 htjb: adding a base class to build density estimators on top of
fe025a1 htjb: tighter prototyping of the functions
c31ad44 htjb: util functions
aadb5dc htjb: beginning to lay out a kde class as a starting point for v2
2ffd42b htjb: adding forward and inverse transformations
c6120a7 htjb: realised I can use tensorflow probability with the jax backend
02e0198 htjb: fiddling with the parameter order in the sampling function
da3b55a htjb: adding in a log prob function with correction
fc7ea0e htjb: kde reimplementation
32a8dec htjb: removing old margarine kde class
8210575 htjb: hmmm implementing the conditional inverse transform sampling is trick…
ce08c15 htjb: adding a jax test train split function
cbc25d0 htjb: removing now redundatn processing code
2596472 htjb: rearranging the file structure
be6ea2a htjb: organising utils into seperate file
26d7204 htjb: an implementation of kmeans in jax
e885e8f htjb: jax implementation of the silhouette score
2157315 htjb: making sure that the bounds function returns (min, max) and fixing ty…
28f60e1 htjb: adding a generic cluster code to build piecewise estimators from any …
545645b htjb: better flagging of pycache and ds_store
7198547 htjb: make base class and abstract class with required methods
ac0b2a2 htjb: doesnt need to inherit from the base class
3911fc7 htjb: dont need to init base class here
752e22f htjb: implementation of the NICE normalising flow
c790283 htjb: bug fix in log prob calculation for kde
f701de6 htjb: depends on flax
89f1364 htjb: better variable name for hidden size
a62c5f6 htjb: updating the log_prob_under_nice function name so it doesnt clash wit…
916d38a htjb: better reporting of the loss
32cb9db htjb: realNVP implementation
e537711 htjb: correcting a bug in the bounds estimation
d82bdbb htjb: calcualte theta ranges seperately for each cluster so that they are w…
53a9e3c htjb: some doc string improvements
356cfda htjb: adding a more detailed discription of the origin of the approximate b…
f44529a htjb: new margarine strap line
ab5f17e htjb: bumping version number because im excited
b5b8d6b htjb: fiddling with the pyproject.toml
c9af1c9 htjb: merging
1475ab6 htjb: missing comma
af61154 htjb: need tfp nightly for compatibility with latest jax
2fe86ba htjb: laying out new statistics code
ded5d79 htjb: code to calcualte kl and bmd
7a41ff4 htjb: jit the log prob functions
5512bcf htjb: removing old maf class so i can rewrite
2814a4a htjb: starting to lay out the base of the maf class
4c16802 htjb: adding a to do list briefly so i can keep track
be26f45 htjb: build the mades, masks and inverse and forward passes
7f20c84 htjb: adding in activations after first layers
547c375 htjb: pretty sure this is set up correctly
7659e96 htjb: okay i think this maf is pretty well set up and optimized
a660b1a htjb: updating the todo
c5c7b48 htjb: modifying the clustered_distribution example
a6d67be htjb: removing some old code
39bf3e1 htjb: removing old test files
38d1b51 htjb: importing the base from the correct file
5640fcb htjb: more stable kernel initialisation
7ef9233 htjb: removing ndims in base class
febb392 htjb: adding tests for kde and some of the util functions
220385e htjb: better name for utils test and formatting on tests
f713279 htjb: tighter constraints on kl and bmd accuracy (I think something wrong w…
b251d85 htjb: testing on a more straight forward problem
69d6521 htjb: i think the latest versions of flax are only tested on 3.11 and above
bc78dd9 htjb: trying to use stochasticity of training to set atol and rtol on kl an…
a1512c3 htjb: playing with the tests. most pass now just trouble with maf training …
75babff htjb: type in instance check in loglike
bc9b154 htjb: bug fixing the permutations in the maf but still needs some tinkering
6f526e3 htjb: better initialisation of final layer in the mades and some tinkering
9ce214a htjb: removing maf stuff from v2 branch after splitting into maf branch
6d24c89 htjb: removing maf tests from v2 branch
c535866 htjb: testing the importance sampling functionality
dafb13f htjb: batched training for nice and realnvp
8fc3cc7 htjb: rough test code for cluster class
93f1e2c htjb: setting batch size for better performance on toy problems
9fb9e33 htjb: specifying max cluster size
a009009 htjb: seperate theta bounds for each flow causes issues when evaluating log…
9d3257e htjb: playing with the flow settings but some inf is creeping in somewhere …
e31a247 htjb: hmmm seems the inf is the true kl
eb7f48c htjb: problem was basically set up wrong so the true kl and bmd was wrong
f4b5c03 htjb: removing the old tutorials notebook
81edf5a htjb: removing old example multimodal distribution code
95ab95b htjb: fiddling with training parameters and batch size
bee8cb6 htjb: dealing with conflict with amster branch
cb0db8d htjb: jax friendly inverse and forward pass and jit of trainingset
1fbf124 htjb: trying to get the cluster tests working
f01ef84 htjb: better type hinting on base estimator
4939ad8 htjb: allow users to set the number of clusters
a2f9ebd htjb: easier test case for clusters and a larger realnvp network
408a17f htjb: workign on save and load functions but something isnt quite right
adff989 htjb: return an array from approximate bounds rather than a tuple
06840d0 htjb: a working save and load function for the NICE model
7b1a21b htjb: surpress some orbax warnings and add save and load tests for nice
3109b4c htjb: addign a functioning save and load to the realnvp class
4c56bd7 htjb: a test for saving and loading of realnvp
58275d1 htjb: custome extensions on save fiels
2b21364 htjb: more consistent save and load functions for kdes
92b5088 htjb: test save and load of kdes and filepath fixes
6bc66c4 htjb: saving and passign theta ranges
ebb8c23 htjb: working version of save for all estimators and cluster save test
5a563b1 htjb: bug fix in kde save function
e7d4af0 htjb: writing bijective transformation for kde with some help from an llm
4bf3767 htjb: jitting fucntions in __call__ for kde
f40e8a2 htjb: removing the todo file
6f12071 htjb: removign the old docs
a782000 htjb: removing old readme
8fc95b5 htjb: laying out new docs
bcb0c96 htjb: Merge branch 'v2' of github.com:htjb/margarine into v2
941c02f htjb: optional doc dependencies
7aa8b9a htjb: set readthedocs up for mkdocs
008e868 htjb: adding an api reference'
1241e35 htjb: fixing the api reference in the docs
efab237 htjb: basic example
c5d6a15 htjb: modifying badge color
d101290 htjb: adding a roadmap and outline of tutorials doc
83ea332 htjb: minor bug fix in doc string
e77ced1 htjb: exapaning the readme intro
1215305 htjb: adding a plan to diversive normalisation techniques
8a97560 htjb: saving the margarine version with the flow and checking against it wh…
9931fb2 htjb: a basic tutorial
2c84a23 htjb: more details in the docs
c8de4cb htjb: code highlighting
240952b htjb: correctly reference last v1
4432d7c htjb: better type hinting on load return
e454836 htjb: more tutorials
cf08cef htjb: Merge branch 'master' into v2
497a01f htjb: updating image link in tutorial
4a9ec33 htjb: modifying check version script for md readme
c4b2d75 htjb: adding a note on statistic errors
8e9dd77 htjb: more permissive type hinting on cluster
0900141 htjb: minor changes to readme
47c87c7 htjb: bug fix in log prob calculations
93d51af htjb: type in a comment
d1b86b0 htjb: better type hinting in train_test_split
73c4bde htjb: ooopa think anpther bug fix
bc25cdc htjb: typo in comment
0cb9d96 htjb: spelt calculated wrong again
7e0d457 htjb: removing a redundant duplicate line
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,5 @@ | ||
| -*/__pycache__ | ||
| -*/.DS_Store | ||
| +**/__pycache__ | ||
| +**/.DS_Store | ||
| margarine.egg-info/ | ||
| .pytest_cache/ | ||
| .ruff_cache/ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,142 @@ | ||
| # margarine: you won't believe it's not your posterior samples! | ||
|
|
||
| ### Marginal Bayesian Statistics | ||
|
|
||
| **Authors:** Harry T.J. Bevins | ||
| **Version:** 2.0.0 | ||
| **Homepage:** https://github.com/htjb/margarine | ||
| **Documentation:** https://margarine.readthedocs.io/ | ||
|
|
||
| [](https://margarine.readthedocs.io/en/latest/?badge=latest) [](https://arxiv.org/abs/2205.12841) | ||
| [](https://arxiv.org/abs/2305.02930) | ||
| [](https://arxiv.org/abs/2207.11457) | ||
| [](https://badge.fury.io/py/margarine) | ||
| [](https://opensource.org/licenses/MIT) | ||
|
|
||
|
|
||
| `margarine` provides a suite of density estimation tools including KDEs, normalizing flows like NICE and RealNVP as well as a novel method for improved performance on multimodal distributions. | ||
|
|
||
| The code can be used to: | ||
|
|
||
| - Emulate posterior distributions from weighted samples (e.g. MCMC, nested sampling) | ||
| - Build non-trivial priors from samples | ||
| - Perform density estimation tasks in general machine learning applications | ||
| - Emulate correctly normalised marginal likelihoods | ||
| - Calculate statistics like the KL divergence between different density estimators and marginal model dimensionalities. | ||
|
|
||
| --- | ||
|
|
||
| ## Installation | ||
|
|
||
| From version 2.0.0 margarine moved to JAX for improved performance. Older versions (1.x.x) using TensorFlow are still available via pip with the last release being 1.4.2. | ||
|
|
||
| Install from Git: | ||
|
|
||
| ```bash | ||
| git clone https://github.com/htjb/margarine.git # or use SSH | ||
| cd margarine | ||
| pip install . | ||
| ``` | ||
|
|
||
| Or via pip: | ||
|
|
||
| ```bash | ||
| pip install margarine | ||
| ``` | ||
|
|
||
| Note: pip may not always give the latest version. | ||
|
|
||
| --- | ||
|
|
||
| ## Getting Started | ||
|
|
||
| All of the density estimators in `margarine` share a common interface, with methods including `train()`, `sample()`, `log_prob()`, `log_like()`, `save()` and `load()`. The example below shows how to train a RealNVP and generate samples. | ||
|
|
||
| ```python | ||
| import jax | ||
| import jax.numpy as jnp | ||
| import matplotlib.pyplot as plt | ||
|
|
||
| from margarine.estimators.realnvp import RealNVP | ||
|
|
||
| nsamples = 5000 | ||
| key = jax.random.PRNGKey(0) | ||
|
|
||
| original_samples = jax.random.multivariate_normal( | ||
|     key, | ||
|     mean=jnp.array([0.0, 0.0]), | ||
|     cov=jnp.array([[1.0, 0.8], [0.8, 1.0]]), | ||
|     shape=(nsamples,), | ||
| ) | ||
|
|
||
| weights = jnp.ones(len(original_samples)) | ||
|
|
||
| realnvp_estimator = RealNVP( | ||
|     original_samples, | ||
|     weights=weights, | ||
|     in_size=2, | ||
|     hidden_size=50, | ||
|     num_layers=6, | ||
|     num_coupling_layers=6, | ||
| ) | ||
|
|
||
| key, subkey = jax.random.split(key) | ||
|
|
||
| realnvp_estimator.train( | ||
|     subkey, | ||
|     learning_rate=1e-3, | ||
|     epochs=2000, | ||
|     patience=50, | ||
|     batch_size=1000, | ||
| ) | ||
|
|
||
| generated_samples = realnvp_estimator.sample(key, num_samples=nsamples) | ||
|
|
||
| plt.scatter( | ||
|     original_samples[:, 0], original_samples[:, 1], alpha=0.5, label="Original Samples" | ||
| ) | ||
| plt.scatter( | ||
|     generated_samples[:, 0], generated_samples[:, 1], alpha=0.5, label="Generated Samples" | ||
| ) | ||
| plt.legend() | ||
| plt.title("RealNVP: Original vs Generated Samples") | ||
| plt.xlabel("X1") | ||
| plt.ylabel("X2") | ||
| plt.show() | ||
| ``` | ||
|
|
||
| For more details, see the documentation. | ||
|
|
||
| --- | ||
|
|
||
| ## Documentation | ||
|
|
||
| Available at: https://margarine.readthedocs.io/. To build locally: | ||
|
|
||
| ```bash | ||
| pip install ".[docs]" | ||
| mkdocs serve | ||
| ``` | ||
|
|
||
|
|
||
| --- | ||
|
|
||
| ## Licence & Citation | ||
|
|
||
| Licensed under MIT. | ||
|
|
||
| If used for academic work, please cite: | ||
|
|
||
| * Main paper: https://arxiv.org/abs/2205.12841 | ||
| * MaxEnt22 proceedings: https://arxiv.org/abs/2207.11457 | ||
| * Piecewise Normalising Flows Paper: https://arxiv.org/abs/2305.02930 | ||
|
|
||
| --- | ||
|
|
||
| ## Contributing | ||
|
|
||
| Contributions and feature suggestions welcome. Open an issue to report bugs or discuss ideas. See `CONTRIBUTING.md` for details. | ||
|
|
||
| The future goals of the project are outlined in `ROADMAP.md`. | ||
|
|
||
|
|
||
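The README added in this diff lists the KL divergence between density estimators among the statistics the package can calculate. As a rough illustration of what such a statistic involves, the sketch below forms a Monte Carlo estimate of $D_{\mathrm{KL}}(P\,\|\,Q)$ by averaging $\log p(x) - \log q(x)$ over samples drawn from $P$. It is not margarine's implementation: it uses two zero-mean Gaussians written in plain JAX so the estimate can be checked against the closed form, and `gaussian_log_prob` and `monte_carlo_kl` are hypothetical names introduced here.

```python
import jax
import jax.numpy as jnp


def gaussian_log_prob(x, cov):
    """Log-density of a zero-mean Gaussian evaluated at each row of x."""
    precision = jnp.linalg.inv(cov)
    mahalanobis = jnp.einsum("ni,ij,nj->n", x, precision, x)
    _, logdet = jnp.linalg.slogdet(cov)
    ndims = x.shape[-1]
    return -0.5 * (mahalanobis + logdet + ndims * jnp.log(2.0 * jnp.pi))


def monte_carlo_kl(log_prob_p, log_prob_q, samples_from_p):
    """Estimate D_KL(P || Q) as the mean of log p(x) - log q(x) for x ~ P."""
    return jnp.mean(log_prob_p(samples_from_p) - log_prob_q(samples_from_p))


# P is the correlated Gaussian from the README example, Q is a unit Gaussian.
cov_p = jnp.array([[1.0, 0.8], [0.8, 1.0]])
cov_q = jnp.eye(2)

key = jax.random.PRNGKey(0)
samples = jax.random.multivariate_normal(
    key, mean=jnp.zeros(2), cov=cov_p, shape=(20000,)
)

kl_estimate = monte_carlo_kl(
    lambda x: gaussian_log_prob(x, cov_p),
    lambda x: gaussian_log_prob(x, cov_q),
    samples,
)

# Closed form for zero-mean Gaussians:
# 0.5 * (tr(cov_q^-1 cov_p) - d + log det cov_q - log det cov_p)
_, logdet_p = jnp.linalg.slogdet(cov_p)
_, logdet_q = jnp.linalg.slogdet(cov_q)
kl_exact = 0.5 * (
    jnp.trace(jnp.linalg.solve(cov_q, cov_p)) - 2.0 + logdet_q - logdet_p
)

print(float(kl_estimate), float(kl_exact))
```

With a trained estimator, the two lambdas would presumably be replaced by the estimators' own `log_prob()` methods and `samples` by draws from the first estimator; the exact call signatures are not shown in this diff, so that substitution is left as an assumption.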
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,18 @@ | ||
| # Roadmap | ||
|
|
||
| The roadmap outlines major planned additions/changes to the code. | ||
|
|
||
| ## Current Development | ||
|
|
||
| - **Masked Autoregressive Flows**: Currently being implemented in [maf](https://github.com/htjb/margarine/tree/maf) | ||
|
|
||
| ## Planned Additions | ||
|
|
||
| - **Alternate Normalisation**: Implement standardization as an alternative to the existing gaussianization. Similar to [margarine_unbounded](https://github.com/mrosep/margarine_unbounded). | ||
| - **Conditional Flows**: Support for SBI and to build $\beta$-flows (see [here](https://arxiv.org/abs/2411.17663)). | ||
| - **Evidence Calculations**: Integration of [floZ](https://arxiv.org/abs/2404.12294) style evidence calculations. | ||
| - **Neural Spline Flows**: Implementation of NSF for improved flexibility. | ||
|
|
||
| ## Future Exploration | ||
|
|
||
| - **Diffusion Models**: Explore addition of diffusion models as alternative density estimators. |
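The "Alternate Normalisation" item in the roadmap contrasts standardization with the existing gaussianization. The sketch below shows one way the two could differ, assuming gaussianization means rescaling each parameter to the unit interval using per-parameter bounds and then applying the standard normal quantile function. The roadmap does not spell this out, and every function name here is hypothetical rather than part of margarine's API.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import ndtr, ndtri


def standardise(x):
    """Shift and scale each column to zero mean and unit standard deviation."""
    mean, std = x.mean(axis=0), x.std(axis=0)
    return (x - mean) / std, (mean, std)


def gaussianise(x, lower, upper):
    """Map each column to (0, 1) via its bounds, then through the normal quantile."""
    return ndtri((x - lower) / (upper - lower))


def inverse_gaussianise(z, lower, upper):
    """Undo gaussianise: normal CDF back to (0, 1), then rescale to the bounds."""
    return ndtr(z) * (upper - lower) + lower


key = jax.random.PRNGKey(0)
x = jax.random.uniform(key, (1000, 2), minval=-3.0, maxval=5.0)

# Pad the bounds slightly so no sample maps exactly to 0 or 1,
# where the normal quantile function is infinite.
pad = 0.01 * (x.max(axis=0) - x.min(axis=0))
lower, upper = x.min(axis=0) - pad, x.max(axis=0) + pad

x_std, (mean, std) = standardise(x)
z = gaussianise(x, lower, upper)
x_back = inverse_gaussianise(z, lower, upper)
```

Under this assumption, standardization needs only a mean and a standard deviation and makes no reference to bounds, whereas the gaussianizing map depends on bounds that have to be estimated from the samples.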