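"""Command-line entry point for projected diffusion experiments.

In 'train' mode, this script fits a score-based diffusion model on the selected
experiment's dataset, tracks validation loss, and saves checkpoints and
intermediate samples. In 'sample' mode, it loads a model checkpoint and runs
annealed Langevin dynamics with the experiment's projection operator, writing
the generated samples to the `outputs/` directory.
"""
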
from projected_diffusion.scoremodel import Model, AnnealedLangevinDynamic
from projected_diffusion import scoremodel, utils
import torch
import os
import argparse
import shutil

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')


def main():
    # Parse args
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", type=str, choices=['train', 'sample'], default='sample', help="Run training or sampling")
    # Setting parameters
    parser.add_argument("--experiment", type=str, choices=['microstructures', 'trajectories', 'motion', 'human', 'other'], default='microstructures', help="Experiment setting")
    parser.add_argument("--projection_path", type=str, default=None, help="If --experiment is 'other', name of the directory under examples/ providing the custom projection operator and dataloader")
    parser.add_argument("--model_path", type=str, default=None, help="Path to a diffusion model checkpoint")
    parser.add_argument("--train_set_path", type=str, default=None, help="Path to the training data (training mode only)")
    parser.add_argument("--val_set_path", type=str, default=None, help="Path to the validation data (training mode only)")
    parser.add_argument("--n_samples", type=int, default=1, help="Number of samples to generate")
    # Model parameters
    parser.add_argument("--eps", type=float, default=1.5e-5, help="Step size epsilon for the Langevin dynamics")
    parser.add_argument("--sigma_min", type=float, default=0.005, help="Minimum sigma for the annealed Langevin dynamics")
    parser.add_argument("--sigma_max", type=float, default=10., help="Maximum sigma for the annealed Langevin dynamics")
    parser.add_argument("--n_steps", type=int, default=10, help="Number of Langevin steps")
    parser.add_argument("--annealed_step", type=int, default=25, help="Number of annealed steps")
    # Training parameters
    parser.add_argument("--total_iteration", type=int, default=3000, help="Total training iterations")
    parser.add_argument("--display_iteration", type=int, default=150, help="Logging frequency (in iterations)")
    parser.add_argument("--run_name", type=str, default='train', help="Run name for logging and saving")
    # Projection parameters
    parser.add_argument("--porosity", type=float, default=0.25, help="Porosity for the microstructure projection (only used for 'microstructures' sampling)")
    parser.add_argument("--gravity", type=float, default=9.8, help="Gravity for the motion projection (only used for 'motion' sampling)")
    parser.add_argument("--obstacles", nargs='*', default=[], help="Obstacles for the trajectories projection (only used for 'trajectories' sampling)")
    args = parser.parse_args()
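
    # Example invocations (illustrative only; values in angle brackets are placeholders):
    #   python main.py --mode train --experiment microstructures \
    #       --train_set_path <train_data> --val_set_path <val_data> --run_name my_run
    #   python main.py --mode sample --experiment microstructures \
    #       --model_path <path/to/ckpt.pt> --porosity 0.25 --n_samples 4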

    if args.mode == 'train':
        _train(args)
    else:
        _sample(args)


def _train(args):
    # Load model
    channels = 4 if args.experiment == 'trajectories' else 1
    model = Model(device, args.n_steps, args.sigma_min, args.sigma_max, channels=channels)
    # Resume training if applicable
    if args.model_path is not None:
        model.load_state_dict(torch.load(args.model_path, map_location=device))
    model.train()
    dynamic = AnnealedLangevinDynamic(args.sigma_min, args.sigma_max, args.n_steps, args.annealed_step, model, None, device, eps=args.eps)
    dynamic.channels = channels
    if args.experiment == 'trajectories':
        dynamic.img_size = 8
    optim = torch.optim.Adam(model.parameters(), lr=0.005)

# Generate dataset
if args.experiment != 'other':
create_outputs_directory(f"examples/{args.experiment}/models")
create_outputs_directory(os.path.join(f"examples/{args.experiment}/models", args.run_name))
models_dir = f"examples/{args.experiment}/models"
else:
create_outputs_directory(f"examples/{args.projection_path}/models")
create_outputs_directory(os.path.join(f"examples/{args.projection_path}/models", args.run_name))
models_dir = f"examples/{args.projection_path}/models"
if args.experiment == 'microstructures':
from examples.microstructures.dataloader import get_dataset
train_loader, validation_loader = get_dataset(args.train_set_path, args.val_set_path)
elif args.experiment == 'trajectories':
# raise NotImplementedError('Support will be added in future releases.')
from examples.trajectories.dataloader import get_dataset
train_loader, validation_loader = get_dataset(args.train_set_path, os.path.join(f"examples/{args.experiment}/models", args.run_name), device)
elif args.experiment == 'motion':
from examples.motion.dataloader import get_dataset
train_loader, validation_loader = get_dataset(args.train_set_path, args.val_set_path)
elif args.experiment == 'human':
raise NotImplementedError('Support will be added in future releases.')
else:
module_path = f"examples.{args.projection_path}.dataloader"
projection_module = __import__(module_path, fromlist=["get_dataset"])
get_dataset = getattr(projection_module, "get_dataset")
train_loader, validation_loader = get_dataset(args.train_set_path, args.val_set_path)
# Set up trainer
current_iteration = 0
best_val_loss = float('inf')
create_outputs_directory(exist_ok=False)
# Set up logging
losses = scoremodel.AverageMeter('Loss', ':.4f')
progress = scoremodel.ProgressMeter(args.total_iteration, [losses], prefix='Iteration ')
    while current_iteration < args.total_iteration:
        ## Training Routine ##
        model.train()
        for data, _ in train_loader:
            data = data.to(device)
            loss = model.loss_fn(data)
            optim.zero_grad()
            loss.backward()
            optim.step()
            losses.update(loss.item())
            progress.display(current_iteration)
            current_iteration += 1
            if current_iteration == args.total_iteration:
                break

        ## Validation Routine ##
        model.eval()
        val_loss_accumulator = 0.0
        val_steps = 0
        with torch.no_grad():
            for data, _ in validation_loader:
                data = data.to(device)
                val_loss = model.loss_fn(data)
                val_loss_accumulator += val_loss.item()
                val_steps += 1
        # Compute average validation loss for the epoch
        avg_validation_loss = val_loss_accumulator / val_steps

        # Checkpointing: keep the model with the best validation loss
        if avg_validation_loss < best_val_loss:
            best_val_loss = avg_validation_loss
            # Save the best model checkpoint
            model_save_path = os.path.join(models_dir, args.run_name, "ckpt.pt")
            torch.save(model.state_dict(), model_save_path)
            # Optionally save the optimizer state
            optimizer_save_path = os.path.join(models_dir, args.run_name, "optim.pt")
            torch.save(optim.state_dict(), optimizer_save_path)

        ## Logging ##
        if current_iteration % args.display_iteration == 0:
            # Save a periodic model checkpoint
            model_save_path = os.path.join(models_dir, args.run_name, f"ckpt_{current_iteration}.pt")
            torch.save(model.state_dict(), model_save_path)
            dynamic = scoremodel.AnnealedLangevinDynamic(args.sigma_min, args.sigma_max, args.n_steps, args.annealed_step, model, None, device, eps=args.eps)
            dynamic.channels = channels
            if args.experiment == 'trajectories':
                dynamic.img_size = 8
            sample = dynamic.sampling(args.n_samples, only_final=True)
            if args.experiment != 'trajectories':
                for i in range(len(sample)):
                    save_path = f'outputs/sample_{i}_step_{current_iteration}.png'
                    utils.save_images(sample[i], save_path)
            else:
                for i in range(len(sample)):
                    save_path = f'outputs/sample_{i}_step_{current_iteration}.pt'
                    torch.save(sample[i], save_path)


def _sample(args):
    # Load projection operator $P_{\mathcal{C}}$
    if args.experiment == 'microstructures':
        from examples.microstructures.projection import Projection
        # TODO: Move this logic to the projection file
        # Compose the upper- and lower-bound projections so the sampled porosity stays within +/- 0.025 of the target
        pc_u = Projection(k=args.porosity + 0.025, threshold=0.0, lower_bound=False)
        pc_l = Projection(k=args.porosity - 0.025, threshold=0.0, lower_bound=True)
        projector = lambda x: pc_l.apply(pc_u.apply(x))
    elif args.experiment == 'trajectories':
        raise NotImplementedError('Support will be added in future releases.')
    elif args.experiment == 'motion':
        from examples.motion.projection import Projection
        projector = lambda x: Projection(acceleration=args.gravity).apply(x)
    elif args.experiment == 'human':
        raise NotImplementedError('Support will be added in future releases.')
    else:
        # Dynamically import the projection operator of a custom experiment under examples/<projection_path>/
        module_path = f"examples.{args.projection_path}.projection"
        projection_module = __import__(module_path, fromlist=["Projection"])
        Projection = getattr(projection_module, "Projection")
        projector = lambda x: Projection().apply(x)

    # Load model
    channels = 4 if args.experiment == 'trajectories' else 1
    model = Model(device, args.n_steps, args.sigma_min, args.sigma_max, channels=channels)
    if args.model_path is not None:
        model.load_state_dict(torch.load(args.model_path, map_location=device))
    model.eval()
    dynamic = AnnealedLangevinDynamic(args.sigma_min, args.sigma_max, args.n_steps, args.annealed_step, model, projector, device, eps=args.eps)
    dynamic.channels = channels
    if args.experiment == 'trajectories':
        dynamic.img_size = 8

    # Sampling function
    def generate_images(n_samples, only_final=True):
        sample = dynamic.sampling(n_samples, only_final)
        return sample

    create_outputs_directory(exist_ok=False)

    # Generate and save samples
    with torch.no_grad():
        sample = generate_images(args.n_samples)
        for i in range(len(sample)):
            save_path = f'outputs/sample_{i}.png'
            utils.save_images(sample[i], save_path)


def create_outputs_directory(path="outputs", exist_ok=True):
    # Remove the directory if it exists and a fresh directory was requested
    if os.path.exists(path) and not exist_ok:
        shutil.rmtree(path)
    # Create the directory (kept as-is if it already exists and exist_ok is True)
    os.makedirs(path, exist_ok=exist_ok)
    print(f"Directory '{path}' has been created.")


if __name__ == '__main__':
    main()