
ml_collections configs, quijote, fix UNet parameter conditioning bug
homerjed committed Sep 7, 2024
1 parent 383e3be commit 8f83d53
Showing 19 changed files with 665 additions and 387 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/publish.yml
@@ -5,7 +5,7 @@ on:
     branches:
       - main
     tags:
-      - "v0.0.*"
+      - "0.0.*"
 
 jobs:
   build:
23 changes: 10 additions & 13 deletions configs/__init__.py
@@ -1,20 +1,17 @@
-from .quijote import QuijoteConfig
-from .mnist import MNISTConfig
-from .cifar10 import CIFAR10Config
-from .flowers import FlowersConfig
+from .quijote import quijote_config
+from .mnist import mnist_config
+from .cifar10 import cifar10_config
+from .flowers import flowers_config
 from .moons import MoonsConfig
 from .dgdm import DgdmConfig
-from .grfs import GRFConfig
-
-Config = QuijoteConfig | MoonsConfig | MNISTConfig | CIFAR10Config | FlowersConfig | GRFConfig
+from .grfs import grfs_config
 
 __all__ = [
-    Config,
-    QuijoteConfig,
+    quijote_config,
+    mnist_config,
+    cifar10_config,
+    flowers_config,
     MoonsConfig,
-    MNISTConfig,
-    CIFAR10Config,
-    FlowersConfig,
     DgdmConfig,
-    GRFConfig
+    grfs_config
 ]
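
Aside: the exports above are now zero-argument functions that each return an ml_collections.ConfigDict, replacing the old class-based configs and the Config union type. (Note that __all__ lists the objects themselves rather than name strings; Python expects strings there, though this only matters for star-imports.) A minimal sketch of how a training entry point might dispatch on dataset name; the get_config helper and its registry are illustrative assumptions, not code from this commit:

import configs

# Hypothetical registry over the config functions exported above.
_CONFIGS = {
    "cifar10": configs.cifar10_config,
    "flowers": configs.flowers_config,
    "grfs": configs.grfs_config,
    "quijote": configs.quijote_config,
}

def get_config(dataset_name):
    # Each value is a zero-argument function returning a ConfigDict.
    return _CONFIGS[dataset_name]()

config = get_config("cifar10")
print(config.model.hidden_size)  # 128, per configs/cifar10.py below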
80 changes: 48 additions & 32 deletions configs/cifar10.py
@@ -1,7 +1,14 @@
-class CIFAR10Config:
-    seed = 0
+import ml_collections
+
+
+def cifar10_config():
+    config = ml_collections.ConfigDict()
+
+    config.seed = 0
 
     # Data
-    dataset_name = "cifar10"
+    config.dataset_name = "cifar10"
 
     # Model
     # model_type = "Mixer"
     # model_args = dict(
@@ -18,36 +25,45 @@
     #     mix_hidden_size=512,
     #     num_blocks=4
     # )
-    model_type: str = "UNet"
-    model_args = dict(
-        is_biggan=False,
-        dim_mults=[1, 2, 4],
-        hidden_size=128,
-        heads=4,
-        dim_head=64,
-        dropout_rate=0.3,
-        num_res_blocks=2,
-        attn_resolutions=[8, 16, 32],
-    )
-    use_ema = False
+    config.model = model = ml_collections.ConfigDict()
+    model.model_type = "UNet"
+    model.is_biggan = False
+    model.dim_mults = [1, 1, 1]
+    model.hidden_size = 128
+    model.heads = 4
+    model.dim_head = 64
+    model.dropout_rate = 0.3
+    model.num_res_blocks = 2
+    model.attn_resolutions = [8, 16, 32]
+    model.final_activation = None
 
     # SDE
-    sde = "VP"
-    t1 = 8.
-    t0 = 1e-5
-    dt = 0.1
-    N = 1000
-    beta_integral = lambda t: t
+    config.sde = sde = ml_collections.ConfigDict()
+    sde.sde = "VP"
+    sde.t1 = 8.
+    sde.t0 = 1e-5
+    sde.dt = 0.1
+    sde.N = 1000
+    sde.beta_integral = lambda t: t
 
     # Sampling
-    sample_size = 5
-    exact_logp = False
-    ode_sample = True
-    eu_sample = True
+    config.use_ema = False
+    config.sample_size = 5
+    config.exact_logp = False
+    config.ode_sample = True
+    config.eu_sample = True
 
     # Optimisation hyperparameters
-    start_step = 0
-    n_steps = 1_000_000
-    lr = 1e-4
-    batch_size = 512 #256 # 256 with UNet
-    print_every = 1_000
-    opt = "adabelief"
+    config.start_step = 0
+    config.n_steps = 1_000_000
+    config.lr = 1e-4
+    config.batch_size = 512 #256 # 256 with UNet
+    config.print_every = 1_000
+    config.opt = "adabelief"
+    config.opt_kwargs = {}
+    config.num_workers = 8
 
     # Other
-    cmap = None
+    config.cmap = None
+
+    return config
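
For readers new to ml_collections: a ConfigDict exposes keys as attributes, nests naturally (config.model and config.sde above), and type-checks reassignments, so an override cannot silently change a field's type, and a locked dict turns key typos into errors. A generic usage sketch (standard ml_collections behaviour, not code from this repository):

import ml_collections

config = ml_collections.ConfigDict()
config.lr = 1e-4
config.model = ml_collections.ConfigDict()
config.model.hidden_size = 128

config.lr = 3e-4                 # allowed: same type as before
config.lock()                    # freeze the set of keys
try:
    config.lrr = 1e-3            # typo: raises once the dict is locked
except AttributeError:
    pass
print(config.model.hidden_size)  # 128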
80 changes: 49 additions & 31 deletions configs/flowers.py
@@ -1,38 +1,56 @@
-class FlowersConfig:
-    seed = 0
+import ml_collections
+
+
+def flowers_config():
+    config = ml_collections.ConfigDict()
+
+    config.seed = 0
 
     # Data
-    dataset_name = "flowers"
-    n_pix = 64
+    config.dataset_name = "flowers"
+    config.n_pix = 64
 
     # Model
-    model_type: str = "UNet"
-    model_args = dict(
-        is_biggan=False,
-        dim_mults=[1, 1, 1],
-        hidden_size=256,
-        heads=2,
-        dim_head=64,
-        dropout_rate=0.3,
-        num_res_blocks=2,
-        attn_resolutions=[8, 32, 64]
-    )
-    use_ema = False
+    config.model = model = ml_collections.ConfigDict()
+    model.model_type = "UNet"
+    model.is_biggan = False
+    model.dim_mults = [1, 1, 1]
+    model.hidden_size = 256
+    model.heads = 2
+    model.dim_head = 64
+    model.dropout_rate = 0.3
+    model.num_res_blocks = 2
+    model.attn_resolutions = [8, 32, 64]
+    model.final_activation = None
 
     # SDE
-    t1 = 8.
-    t0 = 1e-5
-    dt = 0.1
-    beta_integral = lambda t: t
-    # sde: SDE = VPSDE(beta_integral, dt=dt, t0=t0, t1=t1)
+    config.sde = sde = ml_collections.ConfigDict()
+    sde.sde = "VP"
+    sde.t1 = 8.
+    sde.t0 = 1e-5
+    sde.dt = 0.1
+    sde.beta_integral = lambda t: t
+    sde.N = 1000
 
     # Sampling
-    sample_size = 5
-    exact_logp = False
-    ode_sample = True
-    eu_sample = True
+    config.sample_size = 5
+    config.exact_logp = False
+    config.ode_sample = True
+    config.eu_sample = True
 
     # Optimisation hyperparameters
-    start_step = 0
-    n_steps = 1_000_000
-    lr = 1e-4
-    batch_size = 64 #128 #256
-    print_every = 1_000
-    opt = "adabelief"
+    config.use_ema = False
+    config.start_step = 0
+    config.n_steps = 1_000_000
+    config.lr = 1e-4
+    config.batch_size = 64 #128 #256
+    config.print_every = 1_000
+    config.opt = "adabelief"
+    config.opt_kwargs = {}
+    config.num_workers = 8
 
     # Other
-    cmap = None
+    config.cmap = None
+
+    return config
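
One practical payoff of returning ConfigDicts is command-line overriding via ml_collections' config_flags, which by convention loads a config file that defines get_config(). A sketch assuming a thin wrapper module (hypothetical, not part of this commit) that re-exports flowers_config under that name:

# train.py -- illustrative sketch, not part of this commit
from absl import app
from ml_collections import config_flags

# Loads the file given by --config; that file must define get_config().
_CONFIG = config_flags.DEFINE_config_file("config")

def main(argv):
    config = _CONFIG.value
    print(config.lr, config.model.hidden_size)

if __name__ == "__main__":
    app.run(main)

# Usage, overriding a nested field from the command line:
#   python train.py --config=path/to/wrapper.py --config.lr=3e-4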
83 changes: 49 additions & 34 deletions configs/grfs.py
@@ -1,40 +1,55 @@
-class GRFConfig:
-    seed = 0
+import ml_collections
+
+
+def grfs_config():
+    config = ml_collections.ConfigDict()
+
+    config.seed = 0
 
     # Data
-    dataset_name = "grfs"
-    n_pix = 64
+    config.dataset_name = "grfs"
+    config.n_pix = 64
 
     # Model
-    model_type: str = "UNetXY"
-    model_args = dict(
-        is_biggan=False,
-        dim_mults=[1, 1, 1],
-        hidden_size=32,
-        heads=4,
-        dim_head=64,
-        dropout_rate=0.3,
-        num_res_blocks=2,
-        attn_resolutions=[8, 16, 32]
-    )
+    config.model = model = ml_collections.ConfigDict()
+    model.model_type = "UNetXY"
+    model.is_biggan = False
+    model.dim_mults = [1, 1, 1]
+    model.hidden_size = 128
+    model.heads = 4
+    model.dim_head = 64
+    model.dropout_rate = 0.3
+    model.num_res_blocks = 2
+    model.attn_resolutions = [8, 16, 32]
+    model.final_activation = None
 
     # SDE
-    sde = "VP"
-    t1 = 8.
-    t0 = 1e-5
-    dt = 0.1
-    beta_integral = lambda t: t
-    N = 1000
+    config.sde = sde = ml_collections.ConfigDict()
+    sde.sde = "VP"
+    sde.t1 = 8.
+    sde.t0 = 1e-5
+    sde.dt = 0.1
+    sde.N = 1000
+    sde.beta_integral = lambda t: t
 
     # Sampling
-    sample_size = 8
-    exact_logp = False
-    ode_sample = True
-    eu_sample = True
-    use_ema = False
+    config.use_ema = False
+    config.sample_size = 5
+    config.exact_logp = False
+    config.ode_sample = True
+    config.eu_sample = True
 
     # Optimisation hyperparameters
-    start_step = 0
-    n_steps = 1_000_000
-    lr = 1e-4
-    batch_size = 256
-    print_every = 1_000
-    opt = "adabelief"
-    num_workers = 8
+    config.start_step = 0
+    config.n_steps = 1_000_000
+    config.lr = 1e-4
+    config.batch_size = 256
+    config.print_every = 1_000
+    config.opt = "adabelief"
+    config.opt_kwargs = {}
+    config.num_workers = 8
 
     # Other
-    cmap = "coolwarm"
+    config.cmap = "coolwarm"
+
+    return config
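
Each config stores beta_integral = lambda t: t. For a VP SDE this is B(t) = ∫₀ᵗ β(s) ds, so a linear integral corresponds to constant β(t) = 1, and the perturbation kernel follows directly from it. A sketch of the standard VP coefficients (generic VP SDE algebra, not this repository's implementation):

import math

beta_integral = lambda t: t  # B(t), as stored in the configs

def vp_marginal_coeffs(t):
    # Standard VP kernel: x_t ~ N(exp(-B(t)/2) * x_0, (1 - exp(-B(t))) * I)
    B = beta_integral(t)
    return math.exp(-0.5 * B), math.sqrt(1.0 - math.exp(-B))

# Near t0=1e-5 the data is barely perturbed; near t1=8 it is close to N(0, I).
print(vp_marginal_coeffs(1e-5))  # ~(1.000, 0.003)
print(vp_marginal_coeffs(8.0))   # ~(0.018, 1.000)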
(Diffs for the remaining changed files are not shown.)
