diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index edb372f..28a4556 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -5,7 +5,7 @@ on: branches: - main tags: - - "v0.0.*" + - "0.0.*" jobs: build: diff --git a/configs/__init__.py b/configs/__init__.py index 445d110..ff72c4a 100644 --- a/configs/__init__.py +++ b/configs/__init__.py @@ -1,20 +1,17 @@ -from .quijote import QuijoteConfig -from .mnist import MNISTConfig -from .cifar10 import CIFAR10Config -from .flowers import FlowersConfig +from .quijote import quijote_config +from .mnist import mnist_config +from .cifar10 import cifar10_config +from .flowers import flowers_config from .moons import MoonsConfig from .dgdm import DgdmConfig -from .grfs import GRFConfig - -Config = QuijoteConfig | MoonsConfig | MNISTConfig | CIFAR10Config | FlowersConfig | GRFConfig +from .grfs import grfs_config __all__ = [ - Config, - QuijoteConfig, + quijote_config, + mnist_config, + cifar10_config, + flowers_config, MoonsConfig, - MNISTConfig, - CIFAR10Config, - FlowersConfig, DgdmConfig, - GRFConfig + grfs_config ] \ No newline at end of file diff --git a/configs/cifar10.py b/configs/cifar10.py index ead391a..199e48b 100644 --- a/configs/cifar10.py +++ b/configs/cifar10.py @@ -1,7 +1,14 @@ -class CIFAR10Config: - seed = 0 +import ml_collections + + +def cifar10_config(): + config = ml_collections.ConfigDict() + + config.seed = 0 + # Data - dataset_name = "cifar10" + config.dataset_name = "cifar10" + # Model # model_type = "Mixer" # model_args = dict( @@ -18,36 +25,45 @@ class CIFAR10Config: # mix_hidden_size=512, # num_blocks=4 # ) - model_type: str = "UNet" - model_args = dict( - is_biggan=False, - dim_mults=[1, 2, 4], - hidden_size=128, - heads=4, - dim_head=64, - dropout_rate=0.3, - num_res_blocks=2, - attn_resolutions=[8, 16, 32], - ) - use_ema = False + config.model = model = ml_collections.ConfigDict() + model.model_type = "UNet" + model.is_biggan = False + model.dim_mults = [1, 1, 1] + model.hidden_size = 128 + model.heads = 4 + model.dim_head = 64 + model.dropout_rate = 0.3 + model.num_res_blocks = 2 + model.attn_resolutions = [8, 16, 32] + model.final_activation = None + # SDE - sde = "VP" - t1 = 8. - t0 = 1e-5 - dt = 0.1 - N = 1000 - beta_integral = lambda t: t + config.sde = sde = ml_collections.ConfigDict() + sde.sde = "VP" + sde.t1 = 8. 
+ sde.t0 = 1e-5 + sde.dt = 0.1 + sde.N = 1000 + sde.beta_integral = lambda t: t + # Sampling - sample_size = 5 - exact_logp = False - ode_sample = True - eu_sample = True + config.use_ema = False + config.sample_size = 5 + config.exact_logp = False + config.ode_sample = True + config.eu_sample = True + # Optimisation hyperparameters - start_step = 0 - n_steps = 1_000_000 - lr = 1e-4 - batch_size = 512 #256 # 256 with UNet - print_every = 1_000 - opt = "adabelief" + config.start_step = 0 + config.n_steps = 1_000_000 + config.lr = 1e-4 + config.batch_size = 512 #256 # 256 with UNet + config.print_every = 1_000 + config.opt = "adabelief" + config.opt_kwargs = {} + config.num_workers = 8 + # Other - cmap = None \ No newline at end of file + config.cmap = None + + return config \ No newline at end of file diff --git a/configs/flowers.py b/configs/flowers.py index cc9adfa..926576a 100644 --- a/configs/flowers.py +++ b/configs/flowers.py @@ -1,38 +1,56 @@ -class FlowersConfig: - seed = 0 +import ml_collections + + +def flowers_config(): + config = ml_collections.ConfigDict() + + config.seed = 0 + # Data - dataset_name = "flowers" - n_pix = 64 + config.dataset_name = "flowers" + config.n_pix = 64 + # Model - model_type: str = "UNet" - model_args = dict( - is_biggan=False, - dim_mults=[1, 1, 1], - hidden_size=256, - heads=2, - dim_head=64, - dropout_rate=0.3, - num_res_blocks=2, - attn_resolutions=[8, 32, 64] - ) - use_ema = False + config.model = model = ml_collections.ConfigDict() + model.model_type = "UNet" + model.is_biggan = False + model.dim_mults = [1, 1, 1] + model.hidden_size = 256 + model.heads = 2 + model.dim_head = 64 + model.dropout_rate = 0.3 + model.num_res_blocks = 2 + model.attn_resolutions = [8, 32, 64] + model.final_activation = None + # SDE - t1 = 8. - t0 = 1e-5 - dt = 0.1 - beta_integral = lambda t: t + config.sde = sde = ml_collections.ConfigDict() + sde.sde = "VP" + sde.t1 = 8. 
+ sde.t0 = 1e-5 + sde.dt = 0.1 + sde.beta_integral = lambda t: t + sde.N = 1000 # sde: SDE = VPSDE(beta_integral, dt=dt, t0=t0, t1=t1) + # Sampling - sample_size = 5 - exact_logp = False - ode_sample = True - eu_sample = True + config.sample_size = 5 + config.exact_logp = False + config.ode_sample = True + config.eu_sample = True + # Optimisation hyperparameters - start_step = 0 - n_steps = 1_000_000 - lr = 1e-4 - batch_size = 64 #128 #256 - print_every = 1_000 - opt = "adabelief" + config.use_ema = False + config.start_step = 0 + config.n_steps = 1_000_000 + config.lr = 1e-4 + config.batch_size = 64 #128 #256 + config.print_every = 1_000 + config.opt = "adabelief" + config.opt_kwargs = {} + config.num_workers = 8 + # Other - cmap = None \ No newline at end of file + config.cmap = None + + return config diff --git a/configs/grfs.py b/configs/grfs.py index 4431338..da2aa3f 100644 --- a/configs/grfs.py +++ b/configs/grfs.py @@ -1,40 +1,55 @@ -class GRFConfig: - seed = 0 +import ml_collections + + +def grfs_config(): + config = ml_collections.ConfigDict() + + config.seed = 0 + # Data - dataset_name = "grfs" - n_pix = 64 + config.dataset_name = "grfs" + config.n_pix = 64 + # Model - model_type: str = "UNetXY" - model_args = dict( - is_biggan=False, - dim_mults=[1, 1, 1], - hidden_size=32, - heads=4, - dim_head=64, - dropout_rate=0.3, - num_res_blocks=2, - attn_resolutions=[8, 16, 32] - ) + config.model = model = ml_collections.ConfigDict() + model.model_type = "UNetXY" + model.is_biggan = False + model.dim_mults = [1, 1, 1] + model.hidden_size = 128 + model.heads = 4 + model.dim_head = 64 + model.dropout_rate = 0.3 + model.num_res_blocks = 2 + model.attn_resolutions = [8, 16, 32] + model.final_activation = None + # SDE - sde = "VP" - t1 = 8. - t0 = 1e-5 - dt = 0.1 - beta_integral = lambda t: t - N = 1000 + config.sde = sde = ml_collections.ConfigDict() + sde.sde = "VP" + sde.t1 = 8. 
+ sde.t0 = 1e-5 + sde.dt = 0.1 + sde.N = 1000 + sde.beta_integral = lambda t: t + # Sampling - sample_size = 8 - exact_logp = False - ode_sample = True - eu_sample = True - use_ema = False + config.use_ema = False + config.sample_size = 5 + config.exact_logp = False + config.ode_sample = True + config.eu_sample = True + # Optimisation hyperparameters - start_step = 0 - n_steps = 1_000_000 - lr = 1e-4 - batch_size = 256 - print_every = 1_000 - opt = "adabelief" - num_workers = 8 + config.start_step = 0 + config.n_steps = 1_000_000 + config.lr = 1e-4 + config.batch_size = 256 + config.print_every = 1_000 + config.opt = "adabelief" + config.opt_kwargs = {} + config.num_workers = 8 + # Other - cmap = "coolwarm" \ No newline at end of file + config.cmap = "coolwarm" + + return config \ No newline at end of file diff --git a/configs/mnist.py b/configs/mnist.py index 84b2aff..bd8cf0b 100644 --- a/configs/mnist.py +++ b/configs/mnist.py @@ -1,49 +1,54 @@ -class MNISTConfig: - seed = 0 +import ml_collections + + +def mnist_config(): + config = ml_collections.ConfigDict() + + config.seed = 0 + # Data - dataset_name = "mnist" + config.dataset_name = "mnist" + # Model - model_type: str = "UNet" - model_args = dict( - is_biggan=False, - dim_mults=[1, 1, 1], - hidden_size=32, - heads=4, - dim_head=64, - dropout_rate=0.3, - num_res_blocks=2, - attn_resolutions=[8, 16, 32], - final_activation=None - ) - # model_type = "Mixer" - # model_args = dict( - # patch_size=2, - # hidden_size=512, - # mix_patch_size=512, - # mix_hidden_size=512, - # num_blocks=4 - # ) + config.model = model = ml_collections.ConfigDict() + model.model_type = "UNet" + model.is_biggan = False + model.dim_mults = [1, 1, 1] + model.hidden_size = 32 + model.heads = 4 + model.dim_head = 64 + model.dropout_rate = 0.3 + model.num_res_blocks = 2 + model.attn_resolutions = [8, 16, 32] + model.final_activation = None + # SDE - sde = "VP" - t1 = 1. - t0 = 1e-5 - dt = 0.1 - beta_integral = lambda t: t - N = 1000 + config.sde = sde = ml_collections.ConfigDict() + sde.sde = "VP" + sde.t1 = 1. 
+ sde.t0 = 1e-5 + sde.dt = 0.1 + sde.beta_integral = lambda t: t + sde.N = 1000 + # Sampling - sample_size = 8 - exact_logp = False - ode_sample = True - eu_sample = True - use_ema = False + config.sample_size = 8 + config.exact_logp = False + config.ode_sample = True + config.eu_sample = True + config.use_ema = False + # Optimisation hyperparameters - start_step = 0 - n_steps = 1_000_000 - lr = 1e-4 - batch_size = 256 - print_every = 1_000 - opt = "adabelief" - opt_kwargs = {} - num_workers = 8 + config.start_step = 0 + config.n_steps = 1_000_000 + config.lr = 1e-4 + config.batch_size = 256 + config.print_every = 1_000 + config.opt = "adabelief" + config.opt_kwargs = {} + config.num_workers = 8 + # Other - cmap = "gray_r" \ No newline at end of file + config.cmap = "gray_r" + + return config \ No newline at end of file diff --git a/configs/quijote.py b/configs/quijote.py index 469bfe8..e19119c 100644 --- a/configs/quijote.py +++ b/configs/quijote.py @@ -1,52 +1,56 @@ -import jax.random as jr -import optax +import ml_collections -# from _sde import SDE, VPSDE, SubVPSDE, VESDE -img_dir = "/project/ls-gruen/users/jed.homer/1pt_pdf/little_studies/sgm_with_sde_lib/imgs/" -exp_dir = "/project/ls-gruen/users/jed.homer/1pt_pdf/little_studies/sgm_with_sde_lib/exps/" +def quijote_config(): + config = ml_collections.ConfigDict() + config.seed = 0 -class QuijoteConfig: - key = jr.PRNGKey(0) # Data - dataset_name = "Quijote" - n_pix = 128 + config.dataset_name = "quijote" + config.n_pix = 32 + # Model - model_type: str = "UNet" - model_args = dict( - is_biggan=False, - dim_mults=[1, 2, 4], - hidden_size=128, - heads=2, - dim_head=64, - dropout_rate=0.3, - num_res_blocks=2, - attn_resolutions=[8, 32, 64] - # attn_resolutions=reversed( - # [int(n_pix / (2 ** i)) for i in range(len(3))] - # ) - ) - use_ema = True + config.model = model = ml_collections.ConfigDict() + model.model_type = "UNet" + model.is_biggan = False + model.dim_mults = [1, 1, 1] + model.hidden_size = 128 + model.heads = 4 + model.dim_head = 64 + model.dropout_rate = 0.3 + model.num_res_blocks = 2 + model.attn_resolutions = [8, 32, 64] + model.final_activation = None + # SDE - t1 = 8. - t0 = 1e-5 - dt = 0.1 - beta_integral = lambda t: t + config.sde = sde = ml_collections.ConfigDict() + sde.sde = "VP" + sde.t1 = 8. 
+ sde.t0 = 1e-5 + sde.dt = 0.1 + sde.N = 1000 + sde.beta_integral = lambda t: t # sde: SDE = VPSDE(beta_integral, dt=dt, t0=t0, t1=t1) + # Sampling - sample_size = 5 - exact_logp = False - ode_sample = True - eu_sample = True + config.use_ema = False + config.sample_size = 5 + config.exact_logp = False + config.ode_sample = True + config.eu_sample = True + # Optimisation hyperparameters - start_step = 0 - n_steps = 1_000_000 - lr = 1e-4 - batch_size = 32 #64 #256 - print_every = 1_000 - opt = optax.adabelief(lr) + config.start_step = 0 + config.n_steps = 1_000_000 + config.lr = 1e-4 + config.batch_size = 32 + config.print_every = 1_000 + config.opt = "adabelief" + config.opt_kwargs = {} + config.num_workers = 8 + # Other - cmap = "gnuplot" - img_dir = img_dir - exp_dir = exp_dir \ No newline at end of file + config.cmap = "gnuplot" + + return config \ No newline at end of file diff --git a/data/__init__.py b/data/__init__.py index 083c9f7..119b6e1 100644 --- a/data/__init__.py +++ b/data/__init__.py @@ -3,5 +3,6 @@ from .flowers import flowers from .moons import moons from .grfs import grfs +from .quijote import quijote from .utils import Scaler, ScalerDataset, _InMemoryDataLoader, _TorchDataLoader from ._data import get_labels \ No newline at end of file diff --git a/data/cifar10.py b/data/cifar10.py index 0872860..693b6bc 100644 --- a/data/cifar10.py +++ b/data/cifar10.py @@ -1,10 +1,12 @@ import jax.random as jr +import jax.numpy as jnp +from jaxtyping import Key from torchvision import transforms, datasets from .utils import Scaler, ScalerDataset, _TorchDataLoader -def cifar10(key: jr.PRNGKey) -> ScalerDataset: +def cifar10(key: Key) -> ScalerDataset: key_train, key_valid = jr.split(key) n_pix = 32 # Native resolution for CIFAR10 data_shape = (3, n_pix, n_pix) @@ -49,6 +51,12 @@ def cifar10(key: jr.PRNGKey) -> ScalerDataset: valid_dataloader = _TorchDataLoader( valid_dataset, data_shape, context_shape=None, parameter_dim=parameter_dim, key=key_valid ) + + def label_fn(key, n): + Q = None + A = jr.choice(key, jnp.arange(10), (n,))[:, jnp.newaxis] + return Q, A + return ScalerDataset( name="cifar10", train_dataloader=train_dataloader, @@ -56,5 +64,6 @@ def cifar10(key: jr.PRNGKey) -> ScalerDataset: data_shape=data_shape, parameter_dim=parameter_dim, context_shape=None, - scaler=scaler + scaler=scaler, + label_fn=label_fn ) \ No newline at end of file diff --git a/data/flowers.py b/data/flowers.py index 26c1175..aaa8d10 100644 --- a/data/flowers.py +++ b/data/flowers.py @@ -1,10 +1,12 @@ import jax.random as jr +import jax.numpy as jnp +from jaxtyping import Key from torchvision import transforms, datasets from .utils import Scaler, ScalerDataset, _TorchDataLoader -def flowers(key: jr.PRNGKey, n_pix: int) -> ScalerDataset: +def flowers(key: Key, n_pix: int) -> ScalerDataset: key_train, key_valid = jr.split(key) data_shape = (3, n_pix, n_pix) context_shape = (1,) @@ -46,11 +48,18 @@ def flowers(key: jr.PRNGKey, n_pix: int) -> ScalerDataset: train_dataloader = _TorchDataLoader(train_dataset, key=key_train) valid_dataloader = _TorchDataLoader(valid_dataset, key=key_valid) + + def label_fn(key, n): + Q = None + A = jr.choice(key, jnp.arange(10), (n,))[:, jnp.newaxis] + return Q, A + return ScalerDataset( name="flowers", train_dataloader=train_dataloader, valid_dataloader=valid_dataloader, data_shape=data_shape, context_shape=context_shape, - scaler=scaler + scaler=scaler, + label_fn=label_fn ) diff --git a/data/grfs.py b/data/grfs.py index b1271e2..d4ff6bc 100644 --- a/data/grfs.py +++ 
b/data/grfs.py @@ -11,6 +11,8 @@ from .utils import Scaler, ScalerDataset, _TorchDataLoader +data_dir = "/project/ls-gruen/users/jed.homer/data/fields/" + def get_fields(key: Key, Q, n_pix: int, n_fields: int): G = np.zeros((n_fields, 1, n_pix, n_pix)) @@ -44,7 +46,6 @@ def get_data(key: Key, n_pix: int) -> Tuple[np.ndarray, np.ndarray]: """ Load Gaussian and lognormal fields """ - data_dir = "/project/ls-gruen/users/jed.homer/data/fields/" if 0: G = np.load(os.path.join(data_dir, f"G_{n_pix=}.npy")) @@ -95,6 +96,12 @@ def __len__(self): return self.tensors[0].size(0) +def get_grf_labels(n_pix: int) -> np.ndarray: + Q = np.load(os.path.join(data_dir, f"G_{n_pix=}.npy")) + A = np.load(os.path.join(data_dir, f"field_parameters_{n_pix=}.npy")) + return Q, A + + def grfs(key, n_pix, split=0.5): key_data, key_train, key_valid = jr.split(key, 3) @@ -144,6 +151,14 @@ def grfs(key, n_pix, split=0.5): # valid_dataloader = _InMemoryDataLoader( # data=X[n_train:], targets=Q[n_train:], key=key_valid # ) + + def label_fn(key, n): + Q, A = get_grf_labels(n_pix) + ix = jr.choice(key, jnp.arange(len(Q)), (n,)) + Q = Q[ix] + A = A[ix] + return Q, A + return ScalerDataset( name="grfs", train_dataloader=train_dataloader, @@ -151,5 +166,6 @@ def grfs(key, n_pix, split=0.5): data_shape=data_shape, context_shape=context_shape, parameter_dim=parameter_dim, - scaler=scaler + scaler=scaler, + label_fn=label_fn ) \ No newline at end of file diff --git a/data/mnist.py b/data/mnist.py index 0882b8b..2150680 100644 --- a/data/mnist.py +++ b/data/mnist.py @@ -47,6 +47,12 @@ def mnist(key: Key) -> ScalerDataset: valid_dataloader = _InMemoryDataLoader( valid_data, Q=None, A=valid_targets, key=key_valid ) + + def label_fn(key, n): + Q = None + A = jr.choice(key, jnp.arange(10), (n,))[:, jnp.newaxis] + return Q, A + return ScalerDataset( name="mnist", train_dataloader=train_dataloader, @@ -54,5 +60,6 @@ def mnist(key: Key) -> ScalerDataset: data_shape=data_shape, context_shape=None, parameter_dim=parameter_dim, - scaler=scaler + scaler=scaler, + label_fn=label_fn ) \ No newline at end of file diff --git a/data/quijote.py b/data/quijote.py new file mode 100644 index 0000000..464f789 --- /dev/null +++ b/data/quijote.py @@ -0,0 +1,138 @@ +import os +from typing import Tuple +import jax +import jax.numpy as jnp +import jax.random as jr +from jaxtyping import Key, Array +import numpy as np +import torch +from torchvision import transforms + +from .utils import Scaler, ScalerDataset, _TorchDataLoader, _InMemoryDataLoader + +data_dir = "/project/ls-gruen/users/jed.homer/data/fields/" +quijote_dir = "/project/ls-gruen/users/jed.homer/quijote_pdfs/data/" + + +class MapDataset(torch.utils.data.Dataset): + def __init__(self, tensors, transform=None): + # Tuple of (images, contexts, targets), turn them into tensors + self.tensors = tuple( + torch.as_tensor(tensor) for tensor in tensors + ) + self.transform = transform + assert all( + self.tensors[0].size(0) == tensor.size(0) + for tensor in self.tensors + ) + + def __getitem__(self, index): + x = self.tensors[0][index] # Fields + a = self.tensors[1][index] # Parameters + + if self.transform: + x = self.transform(x) + + return x, a + + def __len__(self): + return self.tensors[0].size(0) + + +def get_data(n_pix: int) -> Tuple[Array, Array]: + """ + Load lognormal, gaussian or Quijote fields + """ + X = np.load(os.path.join(data_dir, "quijote_fields.npy"))[:, np.newaxis, ...] 
+ A = np.load(os.path.join(quijote_dir, "ALL_LATIN_PDFS_PARAMETERS.npy")) + + dx = int(256 / n_pix) + X = X.reshape((-1, 1, n_pix, dx, n_pix, dx)).mean(axis=(3, 5)) + return X, A + + +def get_quijote_labels() -> Array: + """ Get labels only of fields dataset """ + Q = np.load(os.path.join(quijote_dir, "ALL_LATIN_PDFS_PARAMETERS.npy")) + return Q + + +def quijote(key, n_pix, split=0.5): + key_train, key_valid = jr.split(key) + + data_shape = (1, n_pix, n_pix) + context_shape = (1, n_pix, n_pix) + parameter_dim = 5 + + X, A = get_data(n_pix) + + print("Quijote data:", X.shape, A.shape) + + min = X.min() + max = X.max() + X = (X - min) / (max - min) # ... -> [0, 1] + + # min = Q.min() + # max = Q.max() + # Q = (Q - min) / (max - min) # ... -> [0, 1] + + scaler = Scaler() # [0,1] -> [-1,1] + + train_transform = transforms.Compose( + [ + transforms.RandomHorizontalFlip(), + transforms.RandomVerticalFlip(), + transforms.Lambda(scaler.forward) + ] + ) + valid_transform = transforms.Compose( + [transforms.Lambda(scaler.forward)] + ) + + n_train = int(split * len(X)) + train_dataset = MapDataset( + (X[:n_train], A[:n_train]), transform=train_transform + ) + valid_dataset = MapDataset( + (X[n_train:], A[n_train:]), transform=valid_transform + ) + # train_dataloader = _TorchDataLoader( + # train_dataset, + # data_shape=data_shape, + # context_shape=None, + # parameter_dim=parameter_dim, + # key=key_train + # ) + # valid_dataloader = _TorchDataLoader( + # valid_dataset, + # data_shape=data_shape, + # context_shape=None, + # parameter_dim=parameter_dim, + # key=key_valid + # ) + + # Don't have many maps + train_dataloader = _InMemoryDataLoader( + X=X[:n_train], A=A[:n_train], key=key_train + ) + valid_dataloader = _InMemoryDataLoader( + X=X[n_train:], A=A[n_train:], key=key_valid + ) + + def label_fn(key, n): + A = get_quijote_labels() + ix = jr.choice(key, jnp.arange(len(A)), (n,)) + Q = None + A = A[ix] + return Q, A + + return ScalerDataset( + name="quijote", + train_dataloader=train_dataloader, + valid_dataloader=valid_dataloader, + data_shape=data_shape, + context_shape=context_shape, + parameter_dim=parameter_dim, + scaler=scaler, + label_fn=label_fn + ) \ No newline at end of file diff --git a/data/utils.py b/data/utils.py index fa88de0..aa6dc79 100644 --- a/data/utils.py +++ b/data/utils.py @@ -1,5 +1,5 @@ import abc -from typing import Tuple, Union, NamedTuple +from typing import Tuple, Union, NamedTuple, Callable from dataclasses import dataclass import jax import jax.numpy as jnp @@ -160,4 +160,5 @@ class ScalerDataset: data_shape: Tuple[int] context_shape: Tuple[int] parameter_dim: int - scaler: Scaler \ No newline at end of file + scaler: Scaler + label_fn: Callable \ No newline at end of file diff --git a/main.py b/main.py index ff20629..8ed71e3 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,4 @@ import os -from copy import deepcopy from typing import Sequence import jax import jax.numpy as jnp @@ -8,10 +7,9 @@ from jaxtyping import Key import optax import numpy as np -from tqdm import trange +import ml_collections import sbgm -from sbgm import utils import data import configs @@ -22,7 +20,7 @@ def get_model( data_shape: Sequence[int], context_shape: Sequence[int], parameter_dim: int, - config: configs.Config + config: ml_collections.ConfigDict ) -> eqx.Module: if model_type == "Mixer": model = sbgm.models.Mixer2d( @@ -36,14 +34,30 @@ def get_model( if model_type == "UNet": model = sbgm.models.UNet( data_shape=data_shape, - **config.model_args, + is_biggan=config.model.is_biggan, + 
dim_mults=config.model.dim_mults, + hidden_size=config.model.hidden_size, + heads=config.model.heads, + dim_head=config.model.dim_head, + dropout_rate=config.model.dropout_rate, + num_res_blocks=config.model.num_res_blocks, + attn_resolutions=config.model.attn_resolutions, + final_activation=config.model.final_activation, a_dim=parameter_dim, key=model_key ) if model_type == "UNetXY": model = sbgm.models.UNetXY( data_shape=data_shape, - **config.model_args, + is_biggan=config.model.is_biggan, + dim_mults=config.model.dim_mults, + hidden_size=config.model.hidden_size, + heads=config.model.heads, + dim_head=config.model.dim_head, + dropout_rate=config.model.dropout_rate, + num_res_blocks=config.model.num_res_blocks, + attn_resolutions=config.model.attn_resolutions, + final_activation=config.model.final_activation, q_dim=context_shape[0], # Just grab channel assuming 'q' is a map like x a_dim=parameter_dim, key=model_key @@ -61,7 +75,7 @@ def get_model( def get_dataset( - dataset_name: str, key: Key, config: configs.Config + dataset_name: str, key: Key, config: ml_collections.ConfigDict ) -> data.ScalerDataset: if dataset_name == "flowers": dataset = data.flowers(key, n_pix=config.n_pix) @@ -73,10 +87,12 @@ def get_dataset( dataset = data.moons(key) if dataset_name == "grfs": dataset = data.grfs(key, n_pix=config.n_pix) + if dataset_name == "quijote": + dataset = data.quijote(key, n_pix=config.n_pix, split=0.9) return dataset -def get_sde(config: configs.Config) -> sbgm.sde.SDE: +def get_sde(config: ml_collections.ConfigDict) -> sbgm.sde.SDE: name = config.sde + "SDE" assert name in ["VESDE", "VPSDE", "SubVPSDE"] sde = getattr(sbgm.sde, name) @@ -89,177 +105,24 @@ def get_sde(config: configs.Config) -> sbgm.sde.SDE: ) -def get_opt( - config: configs.Config -) -> optax.GradientTransformation: +def get_opt(config: ml_collections.ConfigDict): return getattr(optax, config.opt)(config.lr, **config.opt_kwargs) -def train( - key, - # Diffusion model and SDE - model, - sde, - # Dataset - dataset, - # Experiment config - config, - # Reload optimiser or not - reload_opt_state=False, - # Sharding of devices to run on - sharding=None, - # Location to save model, figs, .etc in - save_dir=None, -): - print(f"Training SGM with {config.sde} SDE on {config.dataset_name} dataset.") - - # Experiment and image save directories - exp_dir, img_dir = utils.make_dirs(save_dir, config) - - # Plot SDE over time - utils.plot_sde(sde, filename=os.path.join(exp_dir, "sde.png")) - - # Plot a sample of training data - utils.plot_train_sample( - dataset, - sample_size=config.sample_size, - cmap=config.cmap, - vs=None, - filename=os.path.join(img_dir, "data.png") - ) - - # Model and optimiser save filenames - model_filename = os.path.join( - exp_dir, f"sgm_{dataset.name}_{config.model_type}.eqx" - ) - state_filename = os.path.join( - exp_dir, f"state_{dataset.name}_{config.model_type}.obj" - ) - - # Reload optimiser and state if so desired - opt = get_opt(config) - if not reload_opt_state: - opt_state = opt.init(eqx.filter(model, eqx.is_inexact_array)) - start_step = 0 - else: - state = utils.load_opt_state(filename=state_filename) - model = utils.load_model(model, model_filename) - - opt, opt_state, start_step = state.values() - - print("Loaded model and optimiser state.") - - train_key, sample_key, valid_key = jr.split(key, 3) - - train_total_value = 0 - valid_total_value = 0 - train_total_size = 0 - valid_total_size = 0 - train_losses = [] - valid_losses = [] - dets = [] - - if config.use_ema: - ema_model = 
deepcopy(model) - - with trange(start_step, config.n_steps, colour="red") as steps: - for step, train_batch, valid_batch in zip( - steps, - dataset.train_dataloader.loop(config.batch_size), - dataset.valid_dataloader.loop(config.batch_size) - ): - # Train - x, q, a = sbgm.shard.shard_batch(train_batch, sharding) - _Lt, model, train_key, opt_state = sbgm.train.make_step( - model, sde, x, q, a, train_key, opt_state, opt.update - ) - - train_total_value += _Lt.item() - train_total_size += 1 - train_losses.append(train_total_value / train_total_size) - - if config.use_ema: - ema_model = sbgm.apply_ema(ema_model, model) - - # Validate - x, q, a = sbgm.shard.shard_batch(valid_batch, sharding) - _Lv = sbgm.train.evaluate( - ema_model if config.use_ema else model, sde, x, q, a, valid_key - ) - - valid_total_value += _Lv.item() - valid_total_size += 1 - valid_losses.append(valid_total_value / valid_total_size) - - steps.set_postfix( - { - "Lt" : f"{train_losses[-1]:.3E}", - "Lv" : f"{valid_losses[-1]:.3E}" - } - ) - - if (step % config.print_every) == 0 or step == config.n_steps - 1: - # Sample model - key_Q, key_sample = jr.split(sample_key) # Fixed key - sample_keys = jr.split(key_sample, config.sample_size ** 2) - - # Sample random labels or use parameter prior for labels - Q, A = data.get_labels(key_Q, dataset.name, config) - - # EU sampling - if config.eu_sample: - sample_fn = sbgm.sample.get_eu_sample_fn( - ema_model if config.use_ema else model, sde, dataset.data_shape - ) - eu_sample = jax.vmap(sample_fn)(sample_keys, Q, A) - - # ODE sampling - if config.ode_sample: - sample_fn = sbgm.sample.get_ode_sample_fn( - ema_model if config.use_ema else model, sde, dataset.data_shape - ) - ode_sample = jax.vmap(sample_fn)(sample_keys, Q, A) - - # Sample images and plot - utils.plot_model_sample( - eu_sample, - ode_sample, - dataset, - config, - filename=os.path.join(img_dir, f"samples_{step:06d}"), - ) - - # Save model - utils.save_model( - ema_model if config.use_ema else model, model_filename - ) - - # Save optimiser state - utils.save_opt_state( - opt, - opt_state, - i=step, - filename=state_filename - ) - - # Plot losses etc - utils.plot_metrics(train_losses, valid_losses, dets, step, exp_dir) - return model - - def main(): """ Fit a score-based diffusion model. 
""" config = [ - configs.MNISTConfig, - configs.GRFConfig, - configs.FlowersConfig, - configs.CIFAR10Config - ][0] + configs.mnist_config(), + configs.grfs_config(), + configs.flowers_config(), + configs.cifar10_config(), + configs.quijote_config() + ][-1] - root_dir = "/project/ls-gruen/users/jed.homer/1pt_pdf/little_studies/sgm_lib/sgm/" + root_dir = "/project/ls-gruen/users/jed.homer/1pt_pdf/little_studies/sgm_lib/sbgm/" key = jr.key(config.seed) data_key, model_key, train_key = jr.split(key, 3) @@ -270,11 +133,11 @@ def main(): parameter_dim = dataset.parameter_dim sharding = sbgm.shard.get_sharding() reload_opt_state = False # Restart training or not - + # Diffusion model model = get_model( model_key, - config.model_type, + config.model.model_type, data_shape, context_shape, parameter_dim, @@ -282,10 +145,10 @@ def main(): ) # Stochastic differential equation (SDE) - sde = get_sde(config) + sde = get_sde(config.sde) # Fit model to dataset - model = train( + model = sbgm.train.train( train_key, model, sde, diff --git a/pyproject.toml b/pyproject.toml index cbcf79c..8dd7f6f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,20 +21,23 @@ urls = {repository = "https://github.com/homerjed/sbgm"} dependencies = [ 'jax>=0.4.28', - 'equinox', - 'diffrax', - 'optax', + 'equinox>=0.11.5', + 'diffrax>=0.6.0', + 'optax>=0.2.3', + 'ml_collections', 'numpy', 'matplotlib', - 'cloudpickle', 'einops>=0.8.0', - 'jaxtyping', 'torch>=2.0', 'torchvision', + 'cloudpickle', 'tqdm', 'powerbox' ] [build-system] requires = ["hatchling"] -build-backend = "hatchling.build" \ No newline at end of file +build-backend = "hatchling.build" + +[tool.hatch.build] +include = ["sbgm/*"] \ No newline at end of file diff --git a/sbgm/_sample.py b/sbgm/_sample.py index 4efbd89..11fb7aa 100644 --- a/sbgm/_sample.py +++ b/sbgm/_sample.py @@ -76,9 +76,11 @@ def marginal(i, val): t = time_steps[i] key_eps = jr.fold_in(key, i) + eps_t = jr.normal(key_eps, data_shape) drift, diffusion = reverse_sde.sde(x, t, q, a) mean_x = x - drift * step_size # mu_x = x + drift * -step + # x = [f(x, t) - g^2(t) * score(x, t, q)] * dt + g(t) * sqrt(dt) * eps_t x = mean_x + diffusion * jnp.sqrt(step_size) * eps_t diff --git a/sbgm/_train.py b/sbgm/_train.py index 8076ff0..bb87863 100644 --- a/sbgm/_train.py +++ b/sbgm/_train.py @@ -1,14 +1,32 @@ +import os +from copy import deepcopy +from dataclasses import dataclass from functools import partial -from typing import Tuple +from typing import Tuple, Optional import jax import jax.numpy as jnp import jax.random as jr import jax.tree_util as jtu import equinox as eqx from jaxtyping import PyTree, Key, Array +import ml_collections import optax +from tqdm import trange from ._sde import SDE +from ._sample import get_eu_sample_fn, get_ode_sample_fn +from ._shard import shard_batch +from ._misc import ( + make_dirs, + plot_sde, + plot_train_sample, + plot_model_sample, + plot_metrics, + load_model, + load_opt_state, + save_model, + save_opt_state +) Model = eqx.Module OptState = optax.OptState @@ -107,4 +125,160 @@ def evaluate( ) -> Array: model = eqx.tree_inference(model, True) loss = batch_loss_fn(model, sde, x, q, a, key) - return loss \ No newline at end of file + return loss + + +def get_opt(config: ml_collections.ConfigDict): + return getattr(optax, config.opt)(config.lr, **config.opt_kwargs) + + +def train( + key: Key, + # Diffusion model and SDE + model: eqx.Module, + sde: SDE, + # Dataset + dataset: dataclass, + # Experiment config + config: ml_collections.ConfigDict, + # Reload optimiser 
or not + reload_opt_state: bool = False, + # Sharding of devices to run on + sharding: Optional[jax.sharding.Sharding] = None, + # Location to save model, figs, .etc in + save_dir: Optional[str] = None, +): + print(f"Training SGM with {config.sde.sde} SDE on {config.dataset_name} dataset.") + + # Experiment and image save directories + exp_dir, img_dir = make_dirs(save_dir, config) + + # Plot SDE over time + plot_sde(sde, filename=os.path.join(exp_dir, "sde.png")) + + # Plot a sample of training data + plot_train_sample( + dataset, + sample_size=config.sample_size, + cmap=config.cmap, + vs=None, + filename=os.path.join(img_dir, "data.png") + ) + + # Model and optimiser save filenames + model_filename = os.path.join( + exp_dir, f"sgm_{dataset.name}_{config.model.model_type}.eqx" + ) + state_filename = os.path.join( + exp_dir, f"state_{dataset.name}_{config.model.model_type}.obj" + ) + + # Reload optimiser and state if so desired + opt = get_opt(config) + if not reload_opt_state: + opt_state = opt.init(eqx.filter(model, eqx.is_inexact_array)) + start_step = 0 + else: + state = load_opt_state(filename=state_filename) + model = load_model(model, model_filename) + + opt, opt_state, start_step = state.values() + + print("Loaded model and optimiser state.") + + train_key, sample_key, valid_key = jr.split(key, 3) + + train_total_value = 0 + valid_total_value = 0 + train_total_size = 0 + valid_total_size = 0 + train_losses = [] + valid_losses = [] + dets = [] + + if config.use_ema: + ema_model = deepcopy(model) + + with trange(start_step, config.n_steps, colour="red") as steps: + for step, train_batch, valid_batch in zip( + steps, + dataset.train_dataloader.loop(config.batch_size), + dataset.valid_dataloader.loop(config.batch_size) + ): + # Train + x, q, a = shard_batch(train_batch, sharding) + _Lt, model, train_key, opt_state = make_step( + model, sde, x, q, a, train_key, opt_state, opt.update + ) + + train_total_value += _Lt.item() + train_total_size += 1 + train_losses.append(train_total_value / train_total_size) + + if config.use_ema: + ema_model = apply_ema(ema_model, model) + + # Validate + x, q, a = shard_batch(valid_batch, sharding) + _Lv = evaluate( + ema_model if config.use_ema else model, sde, x, q, a, valid_key + ) + + valid_total_value += _Lv.item() + valid_total_size += 1 + valid_losses.append(valid_total_value / valid_total_size) + + steps.set_postfix( + { + "Lt" : f"{train_losses[-1]:.3E}", + "Lv" : f"{valid_losses[-1]:.3E}" + } + ) + + if (step % config.print_every) == 0 or step == config.n_steps - 1: + # Sample model + key_Q, key_sample = jr.split(sample_key) # Fixed key + sample_keys = jr.split(key_sample, config.sample_size ** 2) + + # Sample random labels or use parameter prior for labels + Q, A = dataset.label_fn(key_Q, config.sample_size ** 2) + + # EU sampling + if config.eu_sample: + sample_fn = get_eu_sample_fn( + ema_model if config.use_ema else model, sde, dataset.data_shape + ) + eu_sample = jax.vmap(sample_fn)(sample_keys, Q, A) + + # ODE sampling + if config.ode_sample: + sample_fn = get_ode_sample_fn( + ema_model if config.use_ema else model, sde, dataset.data_shape + ) + ode_sample = jax.vmap(sample_fn)(sample_keys, Q, A) + + # Sample images and plot + plot_model_sample( + eu_sample, + ode_sample, + dataset, + config, + filename=os.path.join(img_dir, f"samples_{step:06d}"), + ) + + # Save model + save_model( + ema_model if config.use_ema else model, model_filename + ) + + # Save optimiser state + save_opt_state( + opt, + opt_state, + i=step, + 
filename=state_filename + ) + + # Plot losses etc + plot_metrics(train_losses, valid_losses, dets, step, exp_dir) + return model \ No newline at end of file diff --git a/sbgm/models/_unet.py b/sbgm/models/_unet.py index 21c6e18..c7db896 100644 --- a/sbgm/models/_unet.py +++ b/sbgm/models/_unet.py @@ -266,7 +266,6 @@ def __init__( key, ): keys = jax.random.split(key, 7) - # del key data_channels, in_height, in_width = data_shape @@ -282,8 +281,9 @@ def __init__( activation=jax.nn.silu, key=keys[0], ) + self.first_conv = eqx.nn.Conv2d( - data_channels + a_dim if a_dim is not None else data_channels, + data_channels + 1 if a_dim is not None else data_channels, hidden_size, kernel_size=3, padding=1,
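
Usage note (not part of the diff): a minimal sketch of how the refactored pieces above fit together, mirroring main.py in this change set. The get_model/get_sde helpers live in main.py rather than the sbgm package, so the model and SDE construction is only gestured at with placeholders, and the save_dir path is a hypothetical example.

import jax.random as jr
import sbgm, data, configs

config = configs.mnist_config()    # ml_collections.ConfigDict replaces the old config classes
key = jr.key(config.seed)
data_key, model_key, train_key = jr.split(key, 3)

dataset = data.mnist(data_key)     # ScalerDataset now carries a label_fn for sampling-time labels
Q, A = dataset.label_fn(data_key, config.sample_size ** 2)

# Build the score network and SDE as main.py's get_model()/get_sde(config.sde) do; sketched only.
model = ...   # e.g. sbgm.models.UNet(data_shape=dataset.data_shape, ..., key=model_key)
sde = ...     # e.g. a VP SDE using sde.t0, sde.t1, sde.dt, sde.beta_integral from config.sde

# Training now lives in the package (sbgm.train.train) instead of main.py.
model = sbgm.train.train(
    train_key, model, sde, dataset, config,
    reload_opt_state=False,
    sharding=sbgm.shard.get_sharding(),
    save_dir="./out/",   # hypothetical output directory
)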