example, sde prior fix, resnet fix, moons
homerjed committed Sep 10, 2024
1 parent 306f9a8 commit 6a1a290
Showing 13 changed files with 300 additions and 74 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -1,5 +1,5 @@
<h1 align='center'>sbgm</h1>
-<h2 align='center'>Score-based Diffusion models in JAX</h2>
+<h2 align='center'>Score-Based Diffusion Models in JAX</h2>

Implementation and extension of
* [Score-Based Generative Modeling through Stochastic Differential Equations (Song++20)](https://arxiv.org/abs/2011.13456)
@@ -17,6 +17,8 @@ in `jax` and `equinox`.

Diffusion models are deep hierarchical models for data that use neural networks to model the reverse of a diffusion process that adds a sequence of noise perturbations to the data.

+Modern cutting-edge diffusion models (see citations) express both the forward and reverse diffusion processes as a Stochastic Differential Equation (SDE).

-----

<p align="center">
@@ -28,8 +30,6 @@ Diffusion models are deep hierarchical models for data that use neural networks

-----

-Modern cutting-edge diffusion models (see citations) express both the forward and reverse diffusion processes as a Stochastic Differential Equation (SDE).

For any SDE of the form

$$
dx = f(x, t)dt + g(t)dw
$$
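For reference, the corresponding reverse-time SDE from Song++20, which becomes tractable once a network $s_\theta(x, t) \approx \nabla_x \log p_t(x)$ is trained, is

$$
dx = [f(x, t) - g(t)^2 \nabla_x \log p_t(x)]dt + g(t)d\bar{w},
$$

where $\bar{w}$ is a reverse-time Wiener process.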
24 changes: 13 additions & 11 deletions configs/moons.py
@@ -5,41 +5,43 @@
def moons_config():
config = ml_collections.ConfigDict()

-config.seed = 0
+config.seed = 0

# Data
-config.dataset_name = "moons"
+config.dataset_name = "moons"

# Model
config.model = model = ml_collections.ConfigDict()
model.model_type = "mlp"
-model.model.width_size = 128
-model.depth = 2
+model.width_size = 128
+model.depth = 5
model.activation = jax.nn.tanh
model.dropout_p = 0.1

# SDE
config.sde = sde = ml_collections.ConfigDict()
sde.sde = "VP"
-sde.t1 = 4.
sde.t0 = 0.
-sde.dt = 0.02
-sde.beta_integral = lambda t: t ** 2.
+sde.t1 = 4.
+sde.dt = 0.1
+sde.beta_integral = lambda t: t
+sde.N = 1000

# Sampling
config.use_ema = True
config.sample_size = 64
-config.exact_logp = True
-config.ode_sample = True
-config.eu_sample = True
+config.ode_sample = False
+config.eu_sample = False

# Optimisation hyperparameters
config.start_step = 0
-config.n_steps = 200_000
-config.batch_size = 256
+config.n_steps = 50_000
+config.batch_size = 512
config.print_every = 5_000
config.lr = 1e-4
config.opt = "adabelief"
config.opt_kwargs = {}

# Other
config.cmap = "PiYG"
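As a side note on the new `beta_integral`: it plays the role of $B(t) = \int_0^t \beta(s)ds$ for the VP SDE. A minimal sketch of what the linear choice implies for the perturbation kernel (an illustrative helper, not this package's API):

```python
import jax.numpy as jnp

beta_integral = lambda t: t  # B(t) = t, as in the config above

def vp_marginal(x0, t):
    # VP SDE perturbation kernel p_t(x | x_0) is Gaussian with this mean/std.
    B = beta_integral(t)
    mean = x0 * jnp.exp(-0.5 * B)      # signal shrinks as exp(-B(t) / 2)
    std = jnp.sqrt(1. - jnp.exp(-B))   # noise saturates towards unit variance
    return mean, std

mean, std = vp_marginal(jnp.array([1., -1.]), t=4.)  # t1 = 4. in this config
```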
33 changes: 22 additions & 11 deletions data/moons.py
@@ -5,16 +5,20 @@
from .utils import ScalerDataset, _InMemoryDataLoader


+def key_to_seed(key):
+    return int(jnp.asarray(jr.key_data(key)).sum())

def moons(key):
key_train, key_valid = jr.split(key)
data_shape = (2,)
-context_shape = (1,)
+context_shape = None
+parameter_dim = 1

Xt, Yt = make_moons(
-5_000, noise=0.05, random_state=int(key_train.sum())
+40_000, noise=0.05, random_state=key_to_seed(key_train)
)
Xv, Yv = make_moons(
-5_000, noise=0.05, random_state=int(key_valid.sum())
+40_000, noise=0.05, random_state=key_to_seed(key_valid)
)

min = Xt.min()
@@ -29,30 +33,37 @@ def moons(key):
valid_data = (Xv - mean) / std

train_dataloader = _InMemoryDataLoader(
-jnp.asarray(train_data),
-jnp.asarray(Yt)[:, jnp.newaxis],
+X=jnp.asarray(train_data),
+A=jnp.asarray(Yt)[:, jnp.newaxis],
key=key_train
)
valid_dataloader = _InMemoryDataLoader(
-jnp.asarray(valid_data),
-jnp.asarray(Yv)[:, jnp.newaxis],
+X=jnp.asarray(valid_data),
+A=jnp.asarray(Yv)[:, jnp.newaxis],
key=key_valid
)

class _Scaler:
forward: callable
reverse: callable
-def __init__(self):
+def __init__(self, a, b):
# [0, 1] -> [-1, 1]
-self.forward = lambda x: x
+self.forward = lambda x: 2. * (x - a) / (b - a) - 1.
# [-1, 1] -> [0, 1]
-self.reverse = lambda y: y
+self.reverse = lambda y: (y + 1.) * 0.5 * (b - a) + a

+def label_fn(key, n):
+    Q = None
+    A = jr.choice(key, jnp.array([0., 1.]), (n,))[:, jnp.newaxis]
+    return Q, A

return ScalerDataset(
name="moons",
train_dataloader=train_dataloader,
valid_dataloader=valid_dataloader,
data_shape=data_shape,
context_shape=context_shape,
-scaler=_Scaler()
+parameter_dim=parameter_dim,
+scaler=_Scaler(a=Xt.min(), b=Xt.max()),
+label_fn=label_fn
)
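A quick round-trip check of the new affine scaler logic, using standalone lambdas that mirror `_Scaler` (the bounds here are made up):

```python
import jax.numpy as jnp

a, b = -1.2, 2.1  # stand-ins for Xt.min(), Xt.max()
forward = lambda x: 2. * (x - a) / (b - a) - 1.   # [a, b] -> [-1, 1]
reverse = lambda y: (y + 1.) * 0.5 * (b - a) + a  # [-1, 1] -> [a, b]

x = jnp.linspace(a, b, 5)
assert jnp.allclose(reverse(forward(x)), x)
```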
203 changes: 203 additions & 0 deletions examples/moons.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "sbgm"
version = "0.0.12"
version = "0.0.14"
description = "Score-based Diffusion models in JAX."
readme = "README.md"
requires-python ="~=3.12"
8 changes: 1 addition & 7 deletions sbgm/_misc.py
@@ -156,16 +156,10 @@ def samples_onto_ax(_X, fig, ax, vs, cmap):
pass


-def plot_metrics(train_losses, valid_losses, dets, step, exp_dir):
+def plot_metrics(train_losses, valid_losses, step, exp_dir):
if step != 0:
fig, ax = plt.subplots(1, 1, figsize=(8., 4.))
ax.loglog(train_losses)
ax.loglog(valid_losses)
plt.savefig(os.path.join(exp_dir, "loss.png"))
plt.close()

-    if dets is not None:
-        plt.figure()
-        plt.semilogy(dets)
-        plt.savefig(os.path.join(exp_dir, "dets.png"))
-        plt.close()
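A hypothetical call matching the new signature (argument values and directory are made up):

```python
plot_metrics(
    train_losses=[1.0, 0.6, 0.4],
    valid_losses=[1.1, 0.7, 0.5],
    step=5_000,
    exp_dir="runs/moons",  # hypothetical experiment directory
)
```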
5 changes: 3 additions & 2 deletions sbgm/_ode.py
@@ -75,7 +75,7 @@ def log_likelihood(
) -> Tuple[Array, Array]:
""" Compute log-likelihood by solving ODE """

-model = eqx.tree_inference(model, True)
+model = eqx.nn.inference_mode(model, True)

reverse_sde = sde.reverse(model, probability_flow=True)

@@ -113,7 +113,8 @@ def ode(
adjoint=dfx.DirectAdjoint()
)
(z,), (delta_log_likelihood,) = sol.ys
-log_p_y = sde.prior_log_prob(z) + delta_log_likelihood # NOTE: sum() of prior log prob?
+p_z = sde.prior_log_prob(z).sum()
+log_p_y = p_z + delta_log_likelihood
return z, log_p_y


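The added `.sum()` treats `prior_log_prob` as returning per-dimension log-densities. A minimal sketch of the two ingredients of the ODE log-likelihood under that assumption (illustrative only, not the library's internals):

```python
import jax
import jax.numpy as jnp

# log p(y) = log p_T(z) + integral of div f(x(t), t) dt along the
# probability-flow ODE; the divergence is the trace of the drift's Jacobian.
def divergence(f, x, t):
    return jnp.trace(jax.jacfwd(lambda xi: f(xi, t))(x))

# Per-dimension standard-normal log-densities; summing gives the prior term.
def prior_log_prob(z):
    return -0.5 * (z ** 2 + jnp.log(2. * jnp.pi))
```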
50 changes: 26 additions & 24 deletions sbgm/_train.py
@@ -105,7 +105,7 @@ def make_step(
opt_state: OptState,
opt_update: TransformUpdateFn
) -> Tuple[Array, Model, Key, OptState]:
-model = eqx.tree_inference(model, False)
+model = eqx.nn.inference_mode(model, False)
loss_fn = eqx.filter_value_and_grad(batch_loss_fn)
loss, grads = loss_fn(model, sde, x, q, a, key)
updates, opt_state = opt_update(grads, opt_state)
@@ -123,7 +123,7 @@ def evaluate(
a: Array,
key: Key
) -> Array:
-model = eqx.tree_inference(model, True)
+model = eqx.nn.inference_mode(model, True)
loss = batch_loss_fn(model, sde, x, q, a, key)
return loss
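The `eqx.tree_inference` to `eqx.nn.inference_mode` swap here and in `_ode.py` tracks Equinox's rename; both flip `inference` flags (e.g. on `Dropout`) across the model pytree. A quick check of the semantics:

```python
import equinox as eqx
import jax.random as jr

drop = eqx.nn.Dropout(p=0.5)
drop_eval = eqx.nn.inference_mode(drop, value=True)  # switch to inference
x = jr.normal(jr.PRNGKey(0), (4,))
assert (drop_eval(x) == x).all()  # dropout is the identity at inference time
```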

@@ -146,25 +146,14 @@ def train(
# Sharding of devices to run on
sharding: Optional[jax.sharding.Sharding] = None,
# Location to save model, figs, .etc in
-save_dir: Optional[str] = None
+save_dir: Optional[str] = None,
+plot_train_data: bool = False
):
print(f"Training SGM with {config.sde.sde} SDE on {config.dataset_name} dataset.")

# Experiment and image save directories
exp_dir, img_dir = make_dirs(save_dir, config)

-# Plot SDE over time
-plot_sde(sde, filename=os.path.join(exp_dir, "sde.png"))
-
-# Plot a sample of training data
-plot_train_sample(
-    dataset,
-    sample_size=config.sample_size,
-    cmap=config.cmap,
-    vs=None,
-    filename=os.path.join(img_dir, "data.png")
-)

# Model and optimiser save filenames
model_filename = os.path.join(
exp_dir, f"sgm_{dataset.name}_{config.model.model_type}.eqx"
@@ -173,6 +162,19 @@
exp_dir, f"state_{dataset.name}_{config.model.model_type}.obj"
)

+# Plot SDE over time
+plot_sde(sde, filename=os.path.join(exp_dir, "sde.png"))
+
+# Plot a sample of training data
+if plot_train_data:
+    plot_train_sample(
+        dataset,
+        sample_size=config.sample_size,
+        cmap=config.cmap,
+        vs=None,
+        filename=os.path.join(img_dir, "data.png")
+    )

# Reload optimiser and state if so desired
opt = get_opt(config)
if not reload_opt_state:
@@ -194,7 +196,6 @@
valid_total_size = 0
train_losses = []
valid_losses = []
-dets = []

if config.use_ema:
ema_model = deepcopy(model)
@@ -258,13 +259,14 @@
ode_sample = jax.vmap(sample_fn)(sample_keys, Q, A)

# Sample images and plot
-plot_model_sample(
-    eu_sample,
-    ode_sample,
-    dataset,
-    config,
-    filename=os.path.join(img_dir, f"samples_{step:06d}"),
-)
+if config.eu_sample or config.ode_sample:
+    plot_model_sample(
+        eu_sample,
+        ode_sample,
+        dataset,
+        config,
+        filename=os.path.join(img_dir, f"samples_{step:06d}"),
+    )

# Save model
save_model(
Expand All @@ -280,6 +282,6 @@ def train(
)

# Plot losses etc
-plot_metrics(train_losses, valid_losses, dets, step, exp_dir)
+plot_metrics(train_losses, valid_losses, step, exp_dir)

return model
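With `config.use_ema` the EMA weights start from a `deepcopy` of the model; a sketch of the usual update rule (the standard scheme, not necessarily this library's exact implementation):

```python
import equinox as eqx
import jax

def apply_ema(ema_model, model, decay=0.999):
    # Blend floating-point leaves; keep static parts from the EMA copy.
    e_params, e_static = eqx.partition(ema_model, eqx.is_inexact_array)
    m_params, _ = eqx.partition(model, eqx.is_inexact_array)
    blended = jax.tree_util.tree_map(
        lambda e, m: decay * e + (1. - decay) * m, e_params, m_params
    )
    return eqx.combine(blended, e_static)
```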
2 changes: 1 addition & 1 deletion sbgm/models/__init__.py
@@ -59,7 +59,7 @@ def get_model(
depth=config.model.depth,
activation=config.model.activation,
dropout_p=config.model.dropout_p,
-q_dim=parameter_dim,
+a_dim=parameter_dim,
key=model_key
)
if model_type == "CCT":
35 changes: 24 additions & 11 deletions sbgm/models/_mlp.py
@@ -1,4 +1,4 @@
-from typing import Tuple, Callable, Union
+from typing import Tuple, Callable, Union, Optional
import jax
import jax.numpy as jnp
import jax.random as jr
@@ -34,26 +34,31 @@ class ResidualNetwork(eqx.Module):
dropouts: Tuple[eqx.nn.Dropout]
_out: Linear
activation: Callable
+a_dim: Optional[int] = None

def __init__(
self,
in_size: int,
width_size: int,
-depth: int,
-q_dim: int,
-activation: Callable,
+depth: Optional[int] = None,
+a_dim: Optional[int] = None,
+activation: Callable = jax.nn.tanh,
dropout_p: float = 0.,
*,
key: Key
):
""" Time-embedding may be necessary """
in_key, *net_keys, out_key = jr.split(key, 2 + depth)
self._in = Linear(
-in_size + q_dim + 1, width_size, key=in_key
+in_size + a_dim + 1 if a_dim is not None else in_size + 1,
+width_size,
+key=in_key
)
layers = [
Linear(
-width_size + q_dim + 1, width_size, key=_key
+width_size + a_dim + 1 if a_dim is not None else width_size + 1,
+width_size,
+key=_key
)
for _key in net_keys
]
@@ -67,23 +72,31 @@ def __init__(
self.layers = tuple(layers)
self.dropouts = tuple(dropouts)
self.activation = activation
+self.a_dim = a_dim

def __call__(
self,
t: Union[float, Array],
x: Array,
-y: Array,
+q: Array,
+a: Array,
*,
key: Key = None
) -> Array:
t = jnp.atleast_1d(t)
-xyt = jnp.concatenate([x, y, t])
-h0 = self._in(xyt)
+if a is not None and self.a_dim is not None:
+    xat = jnp.concatenate([x, a, t])
+else:
+    xat = jnp.concatenate([x, t])
+h0 = self._in(xat)
h = h0
for l, d in zip(self.layers, self.dropouts):
# Condition on time at each layer
-hyt = jnp.concatenate([h, y, t])
-h = l(hyt)
+if a is not None and self.a_dim is not None:
+    hat = jnp.concatenate([h, a, t])
+else:
+    hat = jnp.concatenate([h, t])
+h = l(hat)
h = d(h, key=key)
h = self.activation(h)
o = self._out(h0 + h)
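Hypothetical usage of the updated `ResidualNetwork` (argument order as in this diff; `q` is assumed unused by the MLP and passed as `None`):

```python
import jax
import jax.numpy as jnp
import jax.random as jr

net = ResidualNetwork(
    in_size=2, width_size=128, depth=5, a_dim=1,
    activation=jax.nn.tanh, dropout_p=0.1, key=jr.PRNGKey(0),
)
x = jnp.zeros((2,))
a = jnp.ones((1,))  # conditioning label; a=None now also works
score = net(0.5, x, None, a, key=jr.PRNGKey(1))  # arguments: (t, x, q, a)
```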
2 changes: 1 addition & 1 deletion sbgm/sde/_subvp.py
@@ -86,6 +86,6 @@ def prior_sample(self, key, shape):
return jr.normal(key, shape)

def prior_log_prob(self, z):
-return jax.vmap(_get_log_prob_fn(scale=1.))(z)
+return _get_log_prob_fn(scale=1.)(z)
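Dropping the `jax.vmap` means `prior_log_prob` now evaluates a single latent, leaving batching to the caller. Assuming `_get_log_prob_fn` returns a per-dimension Gaussian log-density (an assumption consistent with the `.sum()` added in `_ode.py`), the intended pattern is:

```python
import jax
import jax.numpy as jnp

# Assumed form of _get_log_prob_fn: per-dimension Gaussian log-density.
def _get_log_prob_fn(scale=1.):
    return lambda z: -0.5 * (jnp.square(z / scale) + jnp.log(2. * jnp.pi * scale ** 2))

log_p = _get_log_prob_fn(scale=1.)(jnp.zeros((2,))).sum()                    # one sample
batch = jax.vmap(lambda z: _get_log_prob_fn()(z).sum())(jnp.zeros((8, 2)))   # caller vmaps
```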


