-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathlearningtosimulate.py
executable file
·496 lines (422 loc) · 20.6 KB
/
learningtosimulate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
import os
import torch
import json
import numpy as np
import torch_geometric as pyg
import matplotlib.pyplot as plt
import networkx as nx
from tqdm import tqdm
import math
# import torch_scatter
import matplotlib.pyplot as plt
from matplotlib import animation
# from IPython.display import HTML
class InteractionNetwork(pyg.nn.MessagePassing):
    """Interaction Network as proposed in this paper:
    https://proceedings.neurips.cc/paper/2016/hash/3147da8ab4a0437c15ef51a5cc7f2dc4-Abstract.html

    One message-passing step that updates edges and nodes, both with residual
    connections:
      * edge update: ``lin_edge(cat(x_i, x_j, e_ij))`` is added to ``e_ij``;
      * node update: ``lin_node(cat(x_i, sum of incoming messages))`` is added to ``x_i``.

    Args:
        hidden_size: width of node and edge features (input and output).
        layers: number of Linear layers in each of the two MLPs.
    """

    def __init__(self, hidden_size, layers):
        super().__init__()
        self.lin_edge = MLP(hidden_size * 3, hidden_size, hidden_size, layers)
        self.lin_node = MLP(hidden_size * 2, hidden_size, hidden_size, layers)

    def forward(self, x, edge_index, edge_feature):
        """Return ``(updated_node_features, updated_edge_features)``."""
        # ``aggregate`` returns (messages, aggregated), so a single propagate call
        # yields both the new edge features and the per-node message sums.
        edge_out, aggr = self.propagate(edge_index, x=(x, x), edge_feature=edge_feature)
        node_out = self.lin_node(torch.cat((x, aggr), dim=-1))
        edge_out = edge_feature + edge_out
        node_out = x + node_out
        return node_out, edge_out

    def message(self, x_i, x_j, edge_feature):
        """Compute one message per edge from both endpoint features and the edge feature."""
        x = torch.cat((x_i, x_j, edge_feature), dim=-1)
        x = self.lin_edge(x)
        return x

    def aggregate(self, inputs, index, dim_size=None):
        """Sum messages per receiving node and also hand back the raw messages.

        Uses ``Tensor.index_add_`` so no separate ``torch_scatter`` package is
        required; this is the same sum-reduction along ``self.node_dim``.
        """
        if dim_size is None:
            dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
        out_shape = list(inputs.shape)
        out_shape[self.node_dim] = dim_size
        out = inputs.new_zeros(out_shape).index_add_(self.node_dim, index, inputs)
        return (inputs, out)
class MLP(torch.nn.Module):
    """Multi-layer perceptron: ``layers`` Linear layers separated by ReLU.

    Args:
        input_size: number of input features of the first Linear layer.
        hidden_size: width of every intermediate layer.
        output_size: number of features produced by the last Linear layer.
        layers: number of Linear layers.
        layernorm: if True, a LayerNorm over ``output_size`` is appended at the end.
    """

    def __init__(self, input_size, hidden_size, output_size, layers, layernorm=True):
        super().__init__()
        last = layers - 1
        self.layers = torch.nn.ModuleList()
        for depth in range(layers):
            fan_in = hidden_size if depth else input_size
            fan_out = output_size if depth == last else hidden_size
            self.layers.append(torch.nn.Linear(fan_in, fan_out))
            # ReLU between Linear layers only, never after the last one.
            if depth != last:
                self.layers.append(torch.nn.ReLU())
        if layernorm:
            self.layers.append(torch.nn.LayerNorm(output_size))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw Linear weights from N(0, std=1/sqrt(in_features)) and zero the biases."""
        linears = (module for module in self.layers if isinstance(module, torch.nn.Linear))
        for linear in linears:
            std = 1 / math.sqrt(linear.in_features)
            linear.weight.data.normal_(0, std)
            linear.bias.data.fill_(0)

    def forward(self, x):
        """Apply every sub-layer in registration order."""
        out = x
        for module in self.layers:
            out = module(out)
        return out
class LearnedSimulator(torch.nn.Module):
    """Graph Network-based Simulator (GNS): encoder, message-passing processor, decoder.

    The encoder embeds particle types and maps node/edge features to ``hidden_size``.
    The processor applies ``n_mp_layers`` InteractionNetwork blocks, and the decoder
    maps node features to a ``dim``-dimensional output per particle.
    """

    def __init__(
        self,
        hidden_size=128,
        n_mp_layers=10,  # number of GNN layers
        num_particle_types=9,
        particle_type_dim=16,  # embedding dimension of particle types
        dim=2,  # dimension of the world, typical 2D or 3D
        window_size=5,  # the model looks into W frames before the frame to be predicted
    ):
        super().__init__()
        self.window_size = window_size
        # Node input = type embedding + window_size velocities (dim each)
        # + 2 * dim boundary distances (the layout preprocess builds in Data.pos).
        node_input_size = particle_type_dim + dim * (window_size + 2)
        # Edge input = dim-dimensional displacement + its scalar norm.
        edge_input_size = dim + 1
        self.embed_type = torch.nn.Embedding(num_particle_types, particle_type_dim)
        self.node_in = MLP(node_input_size, hidden_size, hidden_size, 3)
        self.edge_in = MLP(edge_input_size, hidden_size, hidden_size, 3)
        self.node_out = MLP(hidden_size, hidden_size, dim, 3, layernorm=False)
        self.n_mp_layers = n_mp_layers
        self.layers = torch.nn.ModuleList(
            [InteractionNetwork(hidden_size, 3) for _ in range(n_mp_layers)]
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform initialisation of the particle-type embedding."""
        torch.nn.init.xavier_uniform_(self.embed_type.weight)

    def forward(self, data):
        """Predict one ``dim``-vector per node from a graph built by ``preprocess``."""
        # Encode: categorical types (data.x) joined with continuous features (data.pos).
        type_embedding = self.embed_type(data.x)
        h_node = self.node_in(torch.cat((type_embedding, data.pos), dim=-1))
        h_edge = self.edge_in(data.edge_attr)
        # Process: stacked message-passing blocks.
        for block in self.layers:
            h_node, h_edge = block(h_node, data.edge_index, edge_feature=h_edge)
        # Decode.
        return self.node_out(h_node)
class OneStepDataset(pyg.data.Dataset):
    """One-step training samples cut from every trajectory of a split.

    Each item covers ``window_length`` consecutive frames. The first
    ``window_length - 1`` frames are the model input and the last one is the
    target; both are passed to ``preprocess``, which builds the graph.

    Args:
        data_path: directory holding ``metadata.json``, ``{split}_offset.json``,
            ``{split}_particle_type.dat`` and ``{split}_position.dat``.
        split: file-name prefix of the split (e.g. "train", "valid").
        window_length: frames per sample (inputs + 1 target).
        noise_std: noise std forwarded to ``preprocess``.
        return_pos: if True, ``get`` returns ``(graph, last_input_positions)``.
    """

    def __init__(self, data_path, split, window_length=7, noise_std=3.0E-4, return_pos=False):
        super().__init__()
        # load dataset from the disk
        with open(os.path.join(data_path, "metadata.json")) as f:
            self.metadata = json.load(f)
        with open(os.path.join(data_path, f"{split}_offset.json")) as f:
            self.offset = json.load(f)
        # JSON object keys are strings; index trajectories by int.
        self.offset = {int(k): v for k, v in self.offset.items()}
        self.window_length = window_length
        self.noise_std = noise_std
        self.return_pos = return_pos
        # Flat memory-mapped arrays; the per-trajectory offsets index into them.
        self.particle_type = np.memmap(os.path.join(data_path, f"{split}_particle_type.dat"), dtype=np.int64, mode="r")
        self.position = np.memmap(os.path.join(data_path, f"{split}_position.dat"), dtype=np.float32, mode="r")
        # Spatial dimension is read from the first trajectory; the code assumes all match.
        for traj in self.offset.values():
            self.dim = traj["position"]["shape"][2]
            break
        # cut particle trajectories according to time slices
        self.windows = []
        for traj in self.offset.values():
            n_frames = traj["position"]["shape"][0]
            size = traj["position"]["shape"][1]
            frame_stride = size * self.dim  # flat elements per frame
            self.windows.extend(
                {
                    "size": size,
                    "type": traj["particle_type"]["offset"],
                    "pos": traj["position"]["offset"] + start * frame_stride,
                }
                for start in range(n_frames - window_length + 1)
            )

    def len(self):
        return len(self.windows)

    def get(self, idx):
        # load corresponding data for this time slice
        window = self.windows[idx]
        size = window["size"]
        type_start = window["type"]
        particle_type = torch.from_numpy(self.particle_type[type_start:type_start + size].copy())
        pos_start = window["pos"]
        flat = self.position[pos_start:pos_start + self.window_length * size * self.dim].copy()
        # reshape raises on a short read, whereas in-place ndarray.resize would
        # silently zero-pad. reshape also avoids resize's reference-count check.
        position_seq = flat.reshape(self.window_length, size, self.dim).transpose(1, 0, 2)
        # now [particle, time, dim]; the last frame is the prediction target
        target_position = torch.from_numpy(position_seq[:, -1])
        position_seq = torch.from_numpy(position_seq[:, :-1])
        # construct the graph
        with torch.no_grad():
            graph = preprocess(particle_type, position_seq, target_position, self.metadata, self.noise_std)
        if self.return_pos:
            return graph, position_seq[:, -1]
        return graph
def generate_noise(position_seq, noise_std):
    """Generate random-walk noise for a trajectory.

    Velocity noise is a Gaussian random walk along the time axis with per-step
    std ``noise_std / sqrt(steps)``, so the accumulated noise on the last velocity
    has std ``noise_std``. Position noise is the running sum of the velocity
    noise, and the first frame is left unperturbed.

    Args:
        position_seq: positions indexed [particle, time, dim].
        noise_std: std of the accumulated velocity noise.

    Returns:
        Tensor shaped like ``position_seq``. It is all zeros when there are fewer
        than two frames, because there are then no velocities to perturb.
    """
    velocity_seq = position_seq[:, 1:] - position_seq[:, :-1]
    time_steps = velocity_seq.size(1)
    if time_steps == 0:
        # Avoid dividing by sqrt(0) below.
        return torch.zeros_like(position_seq)
    velocity_noise = torch.randn_like(velocity_seq) * (noise_std / time_steps ** 0.5)
    velocity_noise = velocity_noise.cumsum(dim=1)
    position_noise = velocity_noise.cumsum(dim=1)
    position_noise = torch.cat((torch.zeros_like(position_noise)[:, 0:1], position_noise), dim=1)
    return position_noise
def preprocess(particle_type, position_seq, target_position, metadata, noise_std, normalize_velocity=True):
    """Preprocess a trajectory window and construct the graph.

    Args:
        particle_type: per-particle type ids, stored as ``Data.x``.
        position_seq: input positions indexed [particle, time, dim].
        target_position: positions of the frame to predict, [particle, dim], or
            None at inference time (``Data.y`` is then None).
        metadata: dataset statistics dict ("default_connectivity_radius",
            "bounds", "vel_mean", "vel_std", "acc_mean", "acc_std").
        noise_std: std of the random-walk input noise. It is also folded into the
            velocity and acceleration normalisation scales.
        normalize_velocity: if True (default), the velocity node features are
            standardised with ``vel_mean`` / ``vel_std``. Pass False to feed raw
            velocities instead, e.g. for checkpoints trained on such features.

    Returns:
        ``pyg.data.Data`` with ``x`` (types), ``edge_index``, ``edge_attr``
        (displacement / radius and its norm), ``y`` (normalised acceleration or
        None) and ``pos`` (flattened velocity history + clipped boundary distances).
    """
    radius = metadata["default_connectivity_radius"]
    # apply noise to the trajectory
    position_noise = generate_noise(position_seq, noise_std)
    position_seq = position_seq + position_noise
    # calculate the velocities of particles
    recent_position = position_seq[:, -1]
    velocity_seq = position_seq[:, 1:] - position_seq[:, :-1]
    # construct the graph based on the distances between particles;
    # max_num_neighbors = particle count so no neighbour inside the radius is dropped
    n_particle = recent_position.size(0)
    edge_index = pyg.nn.radius_graph(recent_position, radius, loop=True,
                                     max_num_neighbors=n_particle)
    # node-level features: velocity, distance to the boundary
    if normalize_velocity:
        velocity_feature = (velocity_seq - torch.tensor(metadata["vel_mean"])) / torch.sqrt(
            torch.tensor(metadata["vel_std"]) ** 2 + noise_std ** 2)
    else:
        velocity_feature = velocity_seq
    boundary = torch.tensor(metadata["bounds"])
    distance_to_lower_boundary = recent_position - boundary[:, 0]
    distance_to_upper_boundary = boundary[:, 1] - recent_position
    distance_to_boundary = torch.cat((distance_to_lower_boundary, distance_to_upper_boundary), dim=-1)
    distance_to_boundary = torch.clip(distance_to_boundary / radius, -1.0, 1.0)
    # edge-level features: displacement (edge_index[0] minus edge_index[1]) and its length
    edge_displacement = (recent_position[edge_index[0]] - recent_position[edge_index[1]]) / radius
    edge_distance = torch.norm(edge_displacement, dim=-1, keepdim=True)
    # ground truth for training
    if target_position is not None:
        last_velocity = velocity_seq[:, -1]
        # Shift the target by the last input frame's noise so the velocity target stays consistent.
        next_velocity = target_position + position_noise[:, -1] - recent_position
        acceleration = next_velocity - last_velocity
        acceleration = (acceleration - torch.tensor(metadata["acc_mean"])) / torch.sqrt(
            torch.tensor(metadata["acc_std"]) ** 2 + noise_std ** 2)
    else:
        acceleration = None
    # return the graph with features
    graph = pyg.data.Data(
        x=particle_type,
        edge_index=edge_index,
        edge_attr=torch.cat((edge_displacement, edge_distance), dim=-1),
        y=acceleration,
        pos=torch.cat((velocity_feature.reshape(velocity_feature.size(0), -1), distance_to_boundary), dim=-1)
    )
    return graph
def rollout(model, data, metadata, noise_std):
    """Autoregressively roll ``model`` out over a whole trajectory.

    Args:
        model: module exposing ``window_size`` that maps a ``preprocess`` graph to
            normalised per-particle accelerations.
        data: dict with "particle_type" and "position" (ground truth indexed
            [time, particle, dim]).
        metadata: dataset metadata; "acc_mean" / "acc_std" de-normalise predictions.
        noise_std: training noise std. The de-normalisation scale is
            ``sqrt(acc_std**2 + noise_std**2)``, mirroring ``preprocess``.

    Returns:
        Tensor indexed [particle, time, dim]. It holds the first
        ``model.window_size + 1`` ground-truth frames followed by predicted frames,
        up to ``data["position"].size(0)`` frames in total.
    """
    device = next(model.parameters()).device
    model.eval()
    # window_size velocities require window_size + 1 positions
    window_size = model.window_size + 1
    total_time = data["position"].size(0)
    traj = data["position"][:window_size].permute(1, 0, 2)
    particle_type = data["particle_type"]
    # Loop-invariant de-normalisation constants.
    acc_scale = torch.sqrt(torch.tensor(metadata["acc_std"]) ** 2 + noise_std ** 2)
    acc_mean = torch.tensor(metadata["acc_mean"])
    with torch.no_grad():
        for _ in range(total_time - window_size):
            graph = preprocess(particle_type, traj[:, -window_size:], None, metadata, 0.0)
            graph = graph.to(device)
            acceleration = model(graph).cpu()
            acceleration = acceleration * acc_scale + acc_mean
            # Integrate: v_{t+1} = v_t + a, then x_{t+1} = x_t + v_{t+1}.
            recent_position = traj[:, -1]
            recent_velocity = recent_position - traj[:, -2]
            new_velocity = recent_velocity + acceleration
            new_position = recent_position + new_velocity
            traj = torch.cat((traj, new_position.unsqueeze(1)), dim=1)
    return traj
def oneStepMSE(simulator, dataloader, metadata, noise):
    """Evaluate one-step predictions; returns two values, loss and MSE.

    Args:
        simulator: model whose output is compared with ``data.y`` (normalised
            accelerations).
        dataloader: iterable of batches supporting ``.to(device)`` and exposing ``.y``.
        metadata: dataset metadata; "acc_std" undoes the normalisation for the MSE.
        noise: noise std folded into the normalisation scale, as in ``preprocess``.

    Returns:
        ``(loss, mse)`` averaged over batches:
          * ``loss`` is the mean squared error in normalised units;
          * ``mse`` is the de-normalised squared error, summed over spatial
            dimensions and averaged over particles.
        Both are NaN if ``dataloader`` yields no batches.
    """
    device = next(simulator.parameters()).device
    total_loss = 0.0
    total_mse = 0.0
    batch_count = 0
    simulator.eval()
    with torch.no_grad():
        # Only the std matters: the mean cancels in (pred - target).
        scale = torch.sqrt(torch.tensor(metadata["acc_std"]) ** 2 + noise ** 2).to(device)
        for data in dataloader:
            data = data.to(device)
            pred = simulator(data)
            error = pred - data.y
            mse = ((error * scale) ** 2).sum(dim=-1).mean()
            loss = (error ** 2).mean()
            total_mse += mse.item()
            total_loss += loss.item()
            batch_count += 1
    if batch_count == 0:
        return float("nan"), float("nan")
    return total_loss / batch_count, total_mse / batch_count
def rolloutMSE(simulator, dataset, noise):
    """Average full-rollout position MSE over every trajectory in ``dataset``.

    Each trajectory is rolled out with ``rollout``, starting from its first
    ground-truth frames. The squared position error is summed over spatial
    dimensions, averaged over particles and frames, then averaged over
    trajectories. Returns NaN for an empty dataset.

    Args:
        simulator: model passed to ``rollout``.
        dataset: indexable dataset with ``metadata`` whose items are dicts with
            "particle_type" and "position" ([time, particle, dim]).
        noise: noise std forwarded to ``rollout`` for de-normalisation.
    """
    total_mse = 0.0
    traj_count = 0
    simulator.eval()
    with torch.no_grad():
        for idx in range(len(dataset)):
            rollout_data = dataset[idx]
            # rollout returns [particle, time, dim]; align with the ground truth layout.
            predicted = rollout(simulator, rollout_data, dataset.metadata, noise).permute(1, 0, 2)
            squared_error = (predicted - rollout_data["position"]) ** 2
            total_mse += squared_error.sum(dim=-1).mean().item()
            traj_count += 1
    return total_mse / traj_count if traj_count else float("nan")


def train(params, simulator, train_loader, valid_loader, valid_rollout_dataset):
    """Train ``simulator`` with Adam on normalised one-step acceleration targets.

    Args:
        params: dict with "epoch", "lr", "noise", "eval_interval",
            "rollout_interval", "save_interval" and, optionally, "model_path"
            (checkpoint directory).
        simulator: model to train; batches are moved to its parameters' device.
        train_loader: training batches with targets in ``data.y``.
        valid_loader: loader for one-step evaluation; ``valid_loader.dataset``
            must expose ``metadata``.
        valid_rollout_dataset: dataset used by ``rolloutMSE`` for rollout evaluation.

    Returns:
        ``(train_loss_list, eval_loss_list, onestep_mse_list, rollout_mse_list)``,
        each a list of ``(step, value)`` pairs.
    """
    device = next(simulator.parameters()).device
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(simulator.parameters(), lr=params["lr"])
    # Stepped once per batch: the learning rate decays by 10x every 5e6 steps.
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1 ** (1 / 5e6))
    # Metadata of the loader actually being evaluated (not a script-level global).
    valid_metadata = valid_loader.dataset.metadata
    # Checkpoint directory: params["model_path"] if given, otherwise a module-level
    # ``model_path`` if one is defined (the __main__ block sets it), else "checkpoints".
    save_dir = params.get("model_path", globals().get("model_path", "checkpoints"))
    # recording loss curve
    train_loss_list = []
    eval_loss_list = []
    onestep_mse_list = []
    rollout_mse_list = []
    total_step = 0
    for epoch in range(params["epoch"]):
        simulator.train()
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}")
        total_loss = 0
        batch_count = 0
        for data in progress_bar:
            optimizer.zero_grad()
            data = data.to(device)
            pred = simulator(data)
            loss = loss_fn(pred, data.y)
            loss.backward()
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()
            batch_count += 1
            progress_bar.set_postfix({"loss": loss.item(), "avg_loss": total_loss / batch_count, "lr": optimizer.param_groups[0]["lr"]})
            total_step += 1
            train_loss_list.append((total_step, loss.item()))
            # evaluation
            if total_step % params["eval_interval"] == 0:
                simulator.eval()
                eval_loss, onestep_mse = oneStepMSE(simulator, valid_loader, valid_metadata, params["noise"])
                eval_loss_list.append((total_step, eval_loss))
                onestep_mse_list.append((total_step, onestep_mse))
                tqdm.write(f"\nEval: Loss: {eval_loss}, One Step MSE: {onestep_mse}")
                simulator.train()
            # do rollout on valid set
            if total_step % params["rollout_interval"] == 0:
                simulator.eval()
                rollout_mse = rolloutMSE(simulator, valid_rollout_dataset, params["noise"])
                rollout_mse_list.append((total_step, rollout_mse))
                tqdm.write(f"\nEval: Rollout MSE: {rollout_mse}")
                simulator.train()
            # save model
            if total_step % params["save_interval"] == 0:
                os.makedirs(save_dir, exist_ok=True)
                torch.save(
                    {
                        "model": simulator.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "scheduler": scheduler.state_dict(),
                    },
                    os.path.join(save_dir, f"checkpoint_{total_step}.pt")
                )
    return train_loss_list, eval_loss_list, onestep_mse_list, rollout_mse_list
class RolloutDataset(pyg.data.Dataset):
    """Whole trajectories of a split, for autoregressive rollout evaluation.

    Each item is a dict:
      * "particle_type": one id per particle;
      * "position": a tensor shaped by the trajectory's recorded
        ``position.shape``, i.e. [time, particle, dim].

    Args:
        data_path: directory holding ``metadata.json``, ``{split}_offset.json``,
            ``{split}_particle_type.dat`` and ``{split}_position.dat``.
        split: file-name prefix of the split (e.g. "valid").
        window_length: stored on the instance; ``get`` does not use it.
    """

    def __init__(self, data_path, split, window_length=7):
        super().__init__()
        # load data from the disk
        with open(os.path.join(data_path, "metadata.json")) as f:
            self.metadata = json.load(f)
        with open(os.path.join(data_path, f"{split}_offset.json")) as f:
            self.offset = json.load(f)
        # JSON object keys are strings; index trajectories by int.
        self.offset = {int(k): v for k, v in self.offset.items()}
        self.window_length = window_length
        self.particle_type = np.memmap(os.path.join(data_path, f"{split}_particle_type.dat"), dtype=np.int64, mode="r")
        self.position = np.memmap(os.path.join(data_path, f"{split}_position.dat"), dtype=np.float32, mode="r")
        # Spatial dimension is read from the first trajectory; the code assumes all match.
        for traj in self.offset.values():
            self.dim = traj["position"]["shape"][2]
            break

    def len(self):
        return len(self.offset)

    def get(self, idx):
        traj = self.offset[idx]
        shape = traj["position"]["shape"]
        time_step, size = shape[0], shape[1]
        type_start = traj["particle_type"]["offset"]
        pos_start = traj["position"]["offset"]
        particle_type = torch.from_numpy(self.particle_type[type_start:type_start + size].copy())
        flat = self.position[pos_start:pos_start + time_step * size * self.dim].copy()
        # reshape raises on a short read, whereas in-place ndarray.resize would
        # silently zero-pad. reshape also avoids resize's reference-count check.
        position = torch.from_numpy(flat.reshape(shape))
        data = {"particle_type": particle_type, "position": position}
        return data
def visualize_pair(particle_type, position_pred, position_gt, metadata):
    """Animate ground truth (left panel) beside a prediction (right panel).

    Args:
        particle_type: per-particle type ids used to colour particles by type.
        position_pred: predicted positions indexed [time, particle, dim].
        position_gt: ground-truth positions indexed [time, particle, dim].
        metadata: dataset metadata; forwarded to ``visualize_prepare`` for axis bounds.

    Returns:
        ``matplotlib.animation.FuncAnimation`` with one frame per ground-truth step.
    """
    fig, (ax_gt, ax_pred) = plt.subplots(1, 2, figsize=(10, 5))
    panels = [
        visualize_prepare(ax_gt, particle_type, position_gt, metadata),
        visualize_prepare(ax_pred, particle_type, position_pred, metadata),
    ]
    ax_gt.set_title("Ground truth")
    ax_pred.set_title("Prediction")
    # Close the current figure; presumably so notebooks don't also render a static copy.
    plt.close()

    def update(step_i):
        artists = []
        for _, frames, markers in panels:
            for kind, marker in markers.items():
                selected = particle_type == kind
                marker.set_data(frames[step_i, selected, 0], frames[step_i, selected, 1])
                artists.append(marker)
        return artists

    n_frames = position_gt.size(0)
    return animation.FuncAnimation(fig, update, frames=np.arange(0, n_frames), interval=10, blit=True)
def visualize_prepare(ax, particle_type, position, metadata, type_to_color=None):
    """Configure ``ax`` for a particle animation and create one empty marker line per type.

    Args:
        ax: matplotlib Axes to configure.
        particle_type: per-particle type ids. Used only to choose which types to
            draw when no colour mapping is available.
        position: positions to animate; returned unchanged for the caller's update loop.
        metadata: dataset metadata; "bounds" ([[x_min, x_max], [y_min, y_max], ...])
            sets the axis limits.
        type_to_color: optional {type_id: colour} mapping. Resolution order:
            1. this argument, if given;
            2. the module-level ``TYPE_TO_COLOR`` (set by the ``__main__`` block), if defined;
            3. every type present in ``particle_type``, coloured from matplotlib's default cycle.

    Returns:
        ``(ax, position, points)`` where ``points`` maps type id -> Line2D.
    """
    if type_to_color is None:
        type_to_color = globals().get("TYPE_TO_COLOR")
    if type_to_color is None:
        # color=None lets matplotlib take the next colour from its property cycle.
        type_to_color = {type_: None for type_ in sorted(set(np.asarray(particle_type).tolist()))}
    bounds = metadata["bounds"]
    ax.set_xlim(bounds[0][0], bounds[0][1])
    ax.set_ylim(bounds[1][0], bounds[1][1])
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect(1.0)
    points = {type_: ax.plot([], [], "o", ms=2, color=color)[0] for type_, color in type_to_color.items()}
    return ax, position, points
if __name__ == "__main__":
    # Script entry point: inspect one sample, train, then roll out and animate the
    # first validation trajectory. Names assigned here are module globals. Some
    # functions in this file may read them (model_path, valid_loader,
    # valid_dataset, TYPE_TO_COLOR), so they are kept at this level.
    print(f"PyTorch has version {torch.__version__} with cuda {torch.version.cuda}")
    device = torch.device('cuda')
    DATASET_NAME = "WaterDropSmall"
    OUTPUT_DIR = os.path.join(DATASET_NAME)
    data_path = OUTPUT_DIR
    model_path = os.path.join("temp", "models", DATASET_NAME)
    rollout_path = os.path.join("temp", "rollouts", DATASET_NAME)
    simulator = LearnedSimulator()
    simulator = simulator.to(device)
    # To resume from a saved checkpoint instead of training from scratch:
    # checkpoint = torch.load(os.path.join(model_path, "checkpoint_24000.pt"))
    # simulator.load_state_dict(checkpoint["model"])
    params = {
        "epoch": 1,
        "batch_size": 4,
        "lr": 1e-4,
        "noise": 3e-4,
        "save_interval": 1000,
        "eval_interval": 1000,
        "rollout_interval": 200000,
    }
    dataset_sample = OneStepDataset(OUTPUT_DIR, "valid", return_pos=True)
    graph, position = dataset_sample[0]
    print(f"The first item in the valid set is a graph: {graph}")
    print(f"This graph has {graph.num_nodes} nodes and {graph.num_edges} edges.")
    print(f"Each node is a particle and each edge is the interaction between two particles.")
    print(
        f"Each node has {graph.num_node_features} categorial feature (Data.x), which represents the type of the node.")
    print(
        f"Each node has a {graph.pos.size(1)}-dim feature vector (Data.pos), which represents the positions and velocities of the particle (node) in several frames.")
    print(
        f"Each edge has a {graph.num_edge_features}-dim feature vector (Data.edge_attr), which represents the relative distance and displacement between particles.")
    print(
        f"The model is expected to predict a {graph.y.size(1)}-dim vector for each node (Data.y), which represents the acceleration of the particle.")
    # # remove directions of edges, because it is a symmetric directed graph.
    # nx_graph = pyg.utils.to_networkx(graph).to_undirected()
    # # remove self loops, because every node has a self loop.
    # nx_graph.remove_edges_from(nx.selfloop_edges(nx_graph))
    # plt.figure(figsize=(7, 7))
    # nx.draw(nx_graph, pos={i: tuple(v) for i, v in enumerate(position)}, node_size=50)
    # plt.show()
    train_dataset = OneStepDataset(data_path, "train", noise_std=params["noise"])
    valid_dataset = OneStepDataset(data_path, "valid", noise_std=params["noise"])
    train_loader = pyg.loader.DataLoader(train_dataset, batch_size=params["batch_size"], shuffle=True, pin_memory=True,
                                         num_workers=2)
    valid_loader = pyg.loader.DataLoader(valid_dataset, batch_size=params["batch_size"], shuffle=False, pin_memory=True,
                                         num_workers=2)
    valid_rollout_dataset = RolloutDataset(data_path, "valid")
    train_loss_list, eval_loss_list, onestep_mse_list, rollout_mse_list = train(params, simulator, train_loader, valid_loader, valid_rollout_dataset)
    # plt.figure()
    # plt.plot(*zip(*train_loss_list), label="train")
    # plt.plot(*zip(*eval_loss_list), label="valid")
    # plt.xlabel('Iterations')
    # plt.ylabel('Loss')
    # plt.title('Loss')
    # plt.legend()
    # plt.show()
    print('Create rollout ...')
    rollout_dataset = RolloutDataset(data_path, "valid")
    simulator.eval()
    rollout_data = rollout_dataset[0]
    rollout_out = rollout(simulator, rollout_data, rollout_dataset.metadata, params["noise"])
    rollout_out = rollout_out.permute(1, 0, 2)
    TYPE_TO_COLOR = {
        3: "black",
        0: "green",
        7: "magenta",
        6: "gold",
        5: "blue",
    }
    print('Create anim ...')
    anim = visualize_pair(rollout_data["particle_type"], rollout_out, rollout_data["position"],
                          rollout_dataset.metadata)
    # IPython's HTML is not imported in this script, so write the animation to
    # rollout_path instead. This needs an ffmpeg movie writer, as to_html5_video() did.
    os.makedirs(rollout_path, exist_ok=True)
    video_file = os.path.join(rollout_path, "rollout.mp4")
    anim.save(video_file)
    print(f"Saved animation to {video_file}")