-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
172 lines (155 loc) · 5.82 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
from typing import *
import os
import logging
import jax
import jax.numpy as jnp
import numpy as np
import ml_collections as mlc
import optax
import wandb
import char_diffusion as cd
import char_diffusion.configs as configs
from char_diffusion.diffusion import get_schedule
from char_diffusion.utils import *
# Module-level logger for this script; when run as `__main__` it is handed to
# `init_logger` together with the output directory (presumably to attach
# handlers there -- see `char_diffusion.utils`).
logger = logging.getLogger(__name__)
def _latest_checkpoint_path(output_dir: str) -> str:
    """Return the path of the rolling "latest" checkpoint inside ``output_dir``."""
    return os.path.join(output_dir, "checkpoint", "latest", "checkpoint.eqx")


def _build_dataloaders(config: mlc.ConfigDict, device_count: int) -> dict:
    """Load the configured dataset and wrap its "train"/"valid" splits in dataloaders.

    ``enwik8`` and ``text8`` are loaded with ``mahoney_dataset``; any other
    dataset name is loaded with ``text_dataset``.

    Returns:
        A dict mapping "train" and "valid" to their dataloaders.
    """
    if config.dataset.name in ['enwik8', 'text8']:
        datasets = mahoney_dataset(config.dataset.path)
    else:
        datasets = text_dataset(config.dataset.path)
    # NOTE(review): the validation loader is also sized with `train.max_steps`
    # although it is only drawn from every `train.eval_every` steps; presumably
    # harmless over-provisioning -- confirm against `dataloader`.
    batch_sizes = {
        "train": config.train.batch_size,
        "valid": config.valid.batch_size,
    }
    return {
        split: dataloader(
            datasets[split],
            seq_len=config.model.seq_len,
            micro_batch_size=batch_size,
            max_steps=config.train.max_steps,
            device_count=device_count,
        )
        for split, batch_size in batch_sizes.items()
    }


def _build_optimizer(config: mlc.ConfigDict):
    """Return Adam preceded by global-norm gradient clipping, configured from ``config.optim``."""
    return optax.chain(
        optax.clip_by_global_norm(config.optim.clip_threshold),
        optax.adam(
            config.optim.lr,
            b1=config.optim.adam_beta1,
            b2=config.optim.adam_beta2,
            eps=1e-8,
        ),
    )


def _maybe_resume(config: mlc.ConfigDict, net, optim_state):
    """Restore ``(net, optim_state)`` from a checkpoint when ``config.train.resume`` is set.

    ``config.checkpoint_path`` is preferred when it is set and exists; otherwise
    the "latest" checkpoint under ``config.output_dir`` is used if that file
    exists. Saved checkpoints hold the index of the last *completed* step, so
    training continues from the following step.

    Returns:
        ``(net, optim_state, start_step)``. ``start_step`` is 0 when nothing was
        restored.
    """
    if not config.train.resume:
        return net, optim_state, 0
    if config.checkpoint_path is not None and Path(config.checkpoint_path).exists():
        path = config.checkpoint_path
    else:
        # Check the checkpoint file itself: the output directory alone may exist
        # (the script entry point creates it) without any checkpoint inside.
        latest = _latest_checkpoint_path(config.output_dir)
        path = latest if os.path.isfile(latest) else None
    if path is None:
        logger.warning("Resume requested but no checkpoint was found; starting from step 0.")
        return net, optim_state, 0
    net, optim_state, last_step = load_state_dict(
        path=path,
        tree=(net, optim_state, 0),
    )
    # The saved step's update is already in `net`; don't repeat it.
    return net, optim_state, last_step + 1


def _log_samples(config: mlc.ConfigDict, diffuser, net, key, num_samples: int = 8):
    """Generate ``num_samples`` sequences with ``diffuser`` and log their decoded text."""
    samples = diffuser.generate(
        net,
        shape=(num_samples, config.model.bit_width, config.model.max_gen_len),
        num_steps=config.model.num_gen_steps,
        bit_width=config.model.bit_width,
        key=key,
        time_delta=config.model.time_delta,
    )
    # `np.asarray` copies to host memory; it replaces the deprecated/removed
    # `.device_buffer.to_py()` accessors.
    samples = np.asarray(samples.squeeze(1))
    sample_log = "\nSamples:\n" + "".join(f"➜ {decode(sample)}\n" for sample in samples)
    logger.info(sample_log)


def train(config: mlc.ConfigDict):
    """Train a ``cd.UNet1d`` with ``cd.CharDiffusion`` as described by ``config``.

    Optionally starts a wandb run. Builds the train/valid dataloaders, the
    network and a clipped-Adam optimizer, and optionally resumes from a
    checkpoint. Then it runs training steps up to ``config.train.max_steps``.
    Along the way it periodically:

    - logs the training loss;
    - evaluates one validation batch and overwrites the "latest" checkpoint;
    - logs decoded generated samples;
    - writes numbered ``step-{step}`` checkpoints.

    Metrics go to wandb only when ``config.use_wandb`` is set.

    Args:
        config: Experiment configuration. Fields read here include
            ``use_wandb``, ``wandb_*``, ``name``, ``seed``, ``dataset.*``,
            ``model.*``, ``optim.*``, ``train.*``, ``valid.batch_size``,
            ``checkpoint_path`` and ``output_dir``.
    """
    device_count = jax.local_device_count()
    logger.info(f"Devices: {jax.devices()}")
    if config.use_wandb:
        wandb.finish()  # Clear out any previous runs.
        wandb.init(
            project=config.wandb_project_name,
            entity=config.wandb_entity,
            name=config.name,
            config=flatten_dict(config.to_dict()),
            id=config.wandb_id,
        )

    dataloaders = _build_dataloaders(config, device_count)
    train_iter = iter(dataloaders["train"])
    valid_iter = iter(dataloaders["valid"])

    # TODO: Update CharDiffusion so we don't have to specify `bit_width` twice;
    # (once in the unet and once in the diffuser).
    key = jax.random.PRNGKey(config.seed)
    net = cd.UNet1d(
        in_channels=1,
        model_channels=config.model.base_channels,
        key=key,
        bit_width=config.model.bit_width,
        num_res_blocks=config.model.num_res_blocks,
        num_heads=config.model.num_heads,
        num_groups=4,
        attn_resolutions=(False, False, True),
        channel_mult=(1, 2, 4),
    )
    optim = _build_optimizer(config)
    optim_state = optim.init(net)
    # NOTE: the PRNG key is not stored in checkpoints, so a resumed run
    # re-derives its key stream from `config.seed`.
    net, optim_state, start_step = _maybe_resume(config, net, optim_state)

    logger.info(f"Network parameter count: ~ {format(count(net), ',')}")
    logger.info(f"Starting Step: {start_step}")
    logger.info(f"Config:\n{config}")

    diffuser = cd.CharDiffusion(
        num_steps=config.model.num_steps,
        bit_width=config.model.bit_width,
        use_self_cond=config.model.use_self_cond,
        gamma_schedule=get_schedule(config.model.schedule),
        optim=optim,
    )
    latest_checkpoint = _latest_checkpoint_path(config.output_dir)

    for step in range(start_step, config.train.max_steps):
        batch = np.expand_dims(next(train_iter), 2)
        key, next_key = jax.random.split(key)
        net, batch_loss, optim_state = diffuser.train_step_pmap(
            net, batch, optim_state, next_key
        )

        # Log training stats.
        if step % config.train.log_every == 0:
            loss = jnp.mean(batch_loss).item()
            if config.use_wandb:
                wandb.log({"train/loss": loss}, step=step)
            logger.info(f"Step: {step}/{config.train.max_steps} | Loss: {loss:.5f}")

        # Evaluate one validation batch and refresh the "latest" checkpoint.
        if step % config.train.eval_every == 0:
            key, valid_key = jax.random.split(key)
            valid_batch = np.expand_dims(next(valid_iter), 2)
            valid_batch_loss = diffuser.eval_step_pmap(net, valid_batch, valid_key)
            valid_loss = np.mean(valid_batch_loss).item()
            if config.use_wandb:
                wandb.log({"valid/loss": valid_loss}, step=step)
            logger.info(
                f"Step: {step}/{config.train.max_steps} | Valid Loss: {valid_loss:.5f}"
            )
            # Stores the last completed step; `_maybe_resume` continues after it.
            save(path=latest_checkpoint, tree=(net, optim_state, step))

        # Decode and log generated samples (skipped at step 0).
        if step % config.train.sample_every == 0 and step != 0:
            key, gen_key = jax.random.split(key)
            _log_samples(config, diffuser, net, gen_key)

        if step % config.train.save_every == 0 and step != 0:
            save(
                path=os.path.join(
                    config.output_dir, "checkpoint", f"step-{step}", "checkpoint.eqx"
                ),
                tree=(net, optim_state, step),
            )
if __name__ == "__main__":
    # Script entry point: build the base config for the War and Peace text file,
    # prepare the output directory and logging, then train.
    # NOTE(review): `id` is presumably a run identifier consumed by the config
    # (e.g. for naming the run/output directory) -- confirm in `configs`.
    run_id = np.random.randint(0, 1e5)
    config = configs.char_diffusion_base_config(
        dataset_path="./data/war_and_peace.txt", id=run_id
    )
    config.wandb_entity = ""
    # Checkpoints are written under `output_dir`, and `init_logger` receives it
    # too (presumably for a log file), so it must exist first.
    os.makedirs(config.output_dir, exist_ok=True)
    init_logger(logger, config.output_dir)
    train(config)