Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Introduction

:margarine: Marginal Bayesian Statistics
:Authors: Harry T.J. Bevins
:Version: 1.4.0
:Version: 1.4.1
:Homepage: https://github.com/htjb/margarine
:Documentation: https://margarine.readthedocs.io/

Expand Down
2 changes: 1 addition & 1 deletion margarine/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.4.0"
__version__ = "1.4.1"
10 changes: 9 additions & 1 deletion margarine/clustered.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ def __init__(
def train(
self,
epochs: int = 100,
patience: int | None = None,
early_stop: bool = False,
loss_type: str = "sum",
) -> None:
Expand All @@ -290,6 +291,10 @@ def train(
Keyword Args:
epochs (int, optional): The number of iterations to
train the neural networks for. Defaults to 100.
patience (int | None, optional): The number of epochs
with no improvement
on the test loss before early stopping is triggered.
Defaults to 2% of epochs.
early_stop (bool, optional): Whether to implement
early stopping or train for the set number of epochs.
If True, training stops when test loss
Expand All @@ -303,7 +308,10 @@ def train(
"""
for i in range(len(self.flow)):
self.flow[i].train(
epochs=epochs, early_stop=early_stop, loss_type=loss_type
epochs=epochs,
patience=patience,
early_stop=early_stop,
loss_type=loss_type,
)

def log_prob(self, params: np.ndarray) -> np.ndarray:
Expand Down
18 changes: 17 additions & 1 deletion margarine/maf.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def gen_mades(self) -> tuple[tfb.Bijector, tfd.TransformedDistribution]:
def train(
self,
epochs: int = 100,
patience: int | None = None,
early_stop: bool = False,
loss_type: str = "sum",
) -> None:
Expand All @@ -228,6 +229,10 @@ def train(
Keyword Args:
epochs (int, optional): The number of iterations to
train the neural networks for. Defaults to 100.
patience (int | None, optional): The number of epochs with
no improvement on the test loss before early
stopping is triggered.
Defaults to 2% of total requested epochs.
early_stop (bool, optional): Whether to implement
early stopping or train for the set number of epochs.
If True, training stops when test loss
Expand All @@ -243,13 +248,21 @@ def train(
raise TypeError("'epochs' is not an integer.")
if type(early_stop) is not bool:
raise TypeError("'early_stop' must be a boolean.")
if patience is not None and type(patience) is not int:
raise TypeError("'patience' must be an integer or None.")

self.epochs = epochs
self.early_stop = early_stop
self.loss_type = loss_type

if patience is None:
self.patience = round((self.epochs / 100) * 2)
else:
self.patience = patience
Comment thread
htjb marked this conversation as resolved.

self.maf = self._training(
self.theta,
self.patience,
self.sample_weights,
self.maf,
self.theta_min,
Expand All @@ -259,6 +272,7 @@ def train(
def _training(
self,
theta: tf.Tensor,
patience: int,
sample_weights: tf.Tensor,
maf: tfd.TransformedDistribution,
theta_min: float | tf.Tensor,
Expand All @@ -273,6 +287,8 @@ def _training(

Args:
theta (tf.Tensor): The samples to train on.
patience (int): The number of epochs with no improvement
on the test loss before early stopping is triggered.
sample_weights (tf.Tensor): The weights associated with the
samples.
maf (tfd.TransformedDistribution): The MAF to be trained.
Expand Down Expand Up @@ -317,7 +333,7 @@ def _training(
minimum_model = maf.copy()
c = 0
if minimum_model:
if c == round((self.epochs / 100) * 2):
if c == patience:
print(
"Early stopped. Epochs used = "
+ str(i)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "margarine"
version = "1.4.0"
version = "1.4.1"
description = "margarine: Posterior Sampling and Marginal Bayesian Statistics "
readme = "README.rst"
authors = [
Expand Down
2 changes: 1 addition & 1 deletion tests/test_clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def d_g(
def test_maf_clustering() -> None:
"""Test clustered MAF marginal statistics calculation."""
bij = clusterMAF(theta, weights=weights)
bij.train(10000, early_stop=True)
bij.train(10000, early_stop=True, patience=400)
file = "saved_maf_cluster.pkl"
bij.save(file)

Expand Down
5 changes: 4 additions & 1 deletion tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def check_stats(label: str) -> None:
assert_allclose(stats[label], value, rtol=1, atol=1)

bij = MAF(theta, weights=weights)
bij.train(10000, early_stop=True)
bij.train(10000, early_stop=True, patience=400)

stats_label = ["KL Divergence", "BMD"]

Expand Down Expand Up @@ -87,6 +87,9 @@ def test_maf_kwargs() -> None:
with pytest.raises(TypeError):
MAF(theta, weights=weights)
bij.train(epochs=100, cluster_labels=5)
with pytest.raises(TypeError):
MAF(theta, weights=weights)
bij.train(epochs=100, patience="foo")


def test_maf_save_load() -> None:
Expand Down
Loading