Skip to content

Version 2#68

Merged
htjb merged 136 commits into
masterfrom
v2
Jan 9, 2026
Merged

Version 2#68
htjb merged 136 commits into
masterfrom
v2

Conversation

@htjb
Copy link
Copy Markdown
Owner

@htjb htjb commented Dec 1, 2025

Description

This PR is designed to update margarine to use JAX and include several different normalising flows.

Impact on issues

Fixes #67 and #29.
#8 is no longer relevant.
I'm leaving error calculation on marginal statistics up to the user, closing #28. Although I will add a discussion to the documentation.
Should also fix #56

Key changes

  • Added NICE and RealNVP density estimators.
  • The code base is now written in JAX.
  • Density estimators are now written in JAX and flax.nnx.
  • Added an abstract BaseDensityEstimator that all density estimators inherit from. This means that each density estimator has an expected set of methods.
  • The implementation of Piecewise Normalising Flows in the old clusterMAF class has been rewritten into the cluster class in margarine/estimators/clustered.py. It takes advantage of the common API for each density estimator to allow users to build Piecewise NFs with any other implemented NF architecture (e.g. users can now build RealNVP PNFs, NICE PNFs, and even Piecewise KDEs).
  • Restructured the files so that the base class is kept in margarine/base/, density estimators are kept in margarine/estimators/ and utilities are kept in margarine/utils/.
  • Added a JAX Kmeans implementation since jax.scipy.stats doesn't have one and it is needed for Piecewise Normalising Flows.
  • New tests have been written.
  • Added load and save functions for NICE, RealNVP, cluster and KDE.
  • Added __call__ function for KDE. To transform samples from the unit hypercube on to the KDE you need conditional inverse transform sampling and this needs to be reimplemented in JAX.
  • A ROADMAP.md has been added with planned features for future version.
  • Documentation has been updated.

Checklist:

  • I have performed a self-review of my own code
  • New and existing unit tests pass locally with my changes (python -m pytest)
  • I have added tests that prove my fix is effective or that my feature works
  • I have appropriately incremented the semantic version number in both README.rst and margarine/_version.py

htjb added 30 commits November 26, 2025 16:04
…other density estimators not just MAFs as in version 1
@htjb htjb marked this pull request as ready for review January 9, 2026 15:07
@htjb htjb requested a review from Copilot January 9, 2026 15:08
Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This pull request represents a major version update (1.4.2 → 2.0.0) that migrates margarine from TensorFlow to JAX. The refactoring introduces significant architectural improvements and adds new density estimators.

Key changes:

  • Complete migration from TensorFlow to JAX/Flax for improved performance and GPU acceleration
  • Added NICE and RealNVP normalizing flow implementations
  • Introduced a BaseDensityEstimator abstract class providing a common API across all estimators
  • Restructured codebase with modular organization (base/, estimators/, utils/)
  • Rewrote clustering implementation to support any density estimator type (Piecewise Normalizing Flows)
  • Added JAX-based K-means clustering implementation
  • Updated documentation to use MkDocs instead of Sphinx

Reviewed changes

Copilot reviewed 40 out of 52 changed files in this pull request and generated 11 comments.

Show a summary per file
File Description
tests/test_utils.py New tests for utility functions (transformations, bounds)
tests/test_importance_sampling.py Rewritten importance sampling tests using JAX and RealNVP
tests/test_estimators.py New comprehensive tests for NICE, RealNVP, and KDE estimators
tests/test_cluster.py New tests for clustered/piecewise normalizing flows
margarine/statistics.py New module for KL divergence, model dimensionality, and integration
margarine/utils/utils.py JAX implementations of transformation and bounds estimation utilities
margarine/utils/kmeans.py JAX-based K-means clustering implementation
margarine/base/baseflow.py Abstract base class defining common density estimator interface
margarine/estimators/realnvp.py RealNVP normalizing flow implementation in JAX/Flax
margarine/estimators/nice.py NICE normalizing flow implementation in JAX/Flax
margarine/estimators/kde.py KDE implementation with JAX support and conditional sampling
margarine/estimators/clustered.py Piecewise NF wrapper supporting any base estimator
pyproject.toml Updated dependencies from TensorFlow to JAX/Flax/Optax
mkdocs.yaml New MkDocs configuration replacing Sphinx
docs/tutorials.md New comprehensive tutorials for v2.0.0
README.md Rewritten README with updated examples and information
margarine/_version.py Version bump to 2.0.0

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread margarine/statistics.py Outdated
Comment thread margarine/utils/utils.py
Comment thread margarine/estimators/realnvp.py
Comment thread margarine/estimators/nice.py
Comment thread margarine/estimators/realnvp.py
Comment thread margarine/estimators/nice.py Outdated
Comment thread margarine/estimators/kde.py
Comment thread margarine/estimators/nice.py Outdated
Comment thread margarine/estimators/nice.py
Comment thread margarine/estimators/realnvp.py
@htjb htjb merged commit e130b20 into master Jan 9, 2026
5 checks passed
@htjb htjb deleted the v2 branch January 9, 2026 16:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

log_like calculation in clusterMAF Unable to save flow.loss_history and flow.test_loss_history in MAF objects

2 participants