diff --git a/README.md b/README.md index 54c745e..6f3efc7 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # margarine: density estimation made easy **Authors:** Harry T.J. Bevins -**Version:** 2.1.0 +**Version:** 2.2.0 **Homepage:** https://github.com/htjb/margarine **Documentation:** https://margarine.readthedocs.io/ diff --git a/margarine/_version.py b/margarine/_version.py index 9aa3f90..8a124bf 100644 --- a/margarine/_version.py +++ b/margarine/_version.py @@ -1 +1 @@ -__version__ = "2.1.0" +__version__ = "2.2.0" diff --git a/margarine/estimators/nice.py b/margarine/estimators/nice.py index eff1831..3fc8a46 100644 --- a/margarine/estimators/nice.py +++ b/margarine/estimators/nice.py @@ -94,7 +94,7 @@ def __init__( layers.append(lambda x: jax.nn.gelu(x)) layers.append( - nnx.Linear(self.hidden_size, self.net_in_size, rngs=nnx_rngs) + nnx.Linear(self.hidden_size, self.pass_size, rngs=nnx_rngs) ) self.mlp = nnx.List( [nnx.Sequential(*layers) for _ in range(self.num_coupling_layers)] diff --git a/margarine/estimators/realnvp.py b/margarine/estimators/realnvp.py index f95c8e3..dda162a 100644 --- a/margarine/estimators/realnvp.py +++ b/margarine/estimators/realnvp.py @@ -113,7 +113,7 @@ def __init__( additive_layers.append( nnx.Linear( self.hidden_size, - self.net_in_size, + self.pass_size, kernel_init=kernel_init, rngs=nnx_rngs, ) @@ -148,7 +148,7 @@ def __init__( scaling_layers.append( nnx.Linear( self.hidden_size, - self.net_in_size, + self.pass_size, kernel_init=kernel_init, rngs=nnx_rngs, ) diff --git a/pyproject.toml b/pyproject.toml index 3415a93..6c9ccd1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "margarine" -version = "2.1.0" +version = "2.2.0" description = "margarine: Posterior Sampling and Marginal Bayesian Statistics " readme = "README.md" authors = [ diff --git a/tests/test_estimators.py b/tests/test_estimators.py index 9715356..9a5b457 100644 --- a/tests/test_estimators.py +++ b/tests/test_estimators.py @@ -244,6 +244,73 @@ def test_realnvp() -> None: os.remove("realnvp_test.marg") +def test_nice_odd_dimensions() -> None: + """Test NICE estimator with an odd number of dimensions.""" + key = jax.random.PRNGKey(45) + + odd_samples = jax.random.multivariate_normal( + key, + mean=jnp.zeros(3), + cov=jnp.eye(3), + shape=(nsamples,), + ) + + nice_estimator = NICE( + odd_samples, + in_size=3, + hidden_size=50, + num_layers=2, + num_coupling_layers=4, + ) + + # check forward/inverse roundtrip + key, subkey = jax.random.split(key) + z = jax.random.normal(subkey, (100, 3)) + forward_transformed = nice_estimator.forward(z) + inverse_transformed = nice_estimator.inverse(forward_transformed) + error = jnp.mean(jnp.abs(inverse_transformed - z)) + assert error < 1e-3 + + # check log_prob runs without crashing + log_probs = nice_estimator.log_prob_under_NICE(z) + assert log_probs.shape == (100,) + assert jnp.all(jnp.isfinite(log_probs)) + + +def test_realnvp_odd_dimensions() -> None: + """Test RealNVP estimator with an odd number of dimensions.""" + key = jax.random.PRNGKey(46) + + odd_samples = jax.random.multivariate_normal( + key, + mean=jnp.zeros(3), + cov=jnp.eye(3), + shape=(nsamples,), + ) + + realnvp_estimator = RealNVP( + odd_samples, + in_size=3, + hidden_size=50, + num_coupling_layers=4, + ) + + # check forward/inverse roundtrip + key, subkey = jax.random.split(key) + z = jax.random.normal(subkey, (100, 3)) + forward_transformed, log_det = realnvp_estimator.forward( + z, return_log_det=True + ) + inverse_transformed = realnvp_estimator.inverse(forward_transformed) + error = jnp.mean(jnp.abs(inverse_transformed - z)) + assert error < 1e-3 + + # check log_prob runs without crashing + log_probs = realnvp_estimator.log_prob_under_RealNVP(z) + assert log_probs.shape == (100,) + assert jnp.all(jnp.isfinite(log_probs)) + + def test_kde() -> None: """Test KDE estimator.""" key = jax.random.PRNGKey(44)