# script_train_unet_jitter.py
import os
import matplotlib as mpl
import torch
from data_management import Jitter, load_dataset
from networks import IterativeNet, UNet
from operators import TVAnalysis, get_tikhonov_matrix
# --- load configuration -----
import config # isort:skip
# ----- general setup -----
# Use the non-interactive "agg" matplotlib backend and place everything on the
# first CUDA device.
mpl.use("agg")
device = torch.device("cuda:0")

# ----- operators -----
# The measurement operator class, its dimensions and its extra keyword
# arguments are all taken from the config module.
OpA = config.meas_op(
    config.m,
    config.n,
    device=device,
    **config.meas_params,
)
OpTV = TVAnalysis(config.n, device=device)

# ----- build linear inverter ------
# A bias-free linear layer whose weight is replaced by the matrix returned
# from get_tikhonov_matrix and excluded from gradient computation.
reg_fac = 2e-2
inverter = torch.nn.Linear(OpA.m, OpA.n, bias=False)
inverter.weight.requires_grad = False
inverter.weight.data = get_tikhonov_matrix(OpA, OpTV, reg_fac)
# ----- network configuration -----
# Class of the sub-network and the keyword arguments it is constructed with.
subnet = UNet
subnet_params = dict(
    in_channels=1,
    out_channels=1,
    drop_factor=0.0,
    base_features=64,
)

# Keyword arguments for the IterativeNet wrapper built around the sub-network.
it_net_params = dict(
    operator=OpA,
    inverter=inverter,
    num_iter=1,
    lam=0.0,
    lam_learnable=False,
    final_dc=False,
)
# ----- training setup ------
mse_loss = torch.nn.MSELoss(reduction="sum")


def loss_func(pred, tar):
    """Return the summed squared error divided by the size of dimension 0.

    The squared differences between ``pred`` and ``tar`` are summed over all
    elements and then normalised by ``pred.shape[0]`` only, not by the total
    number of elements.
    """
    first_dim = pred.shape[0]
    summed_error = mse_loss(pred, tar)
    return summed_error / first_dim
train_phases = 2

# Settings given as a list or tuple are indexed by the phase number when the
# training loop runs; every other setting is shared by all phases.
train_params = {
    "num_epochs": [200, 75],
    "batch_size": [40, 40],
    "loss_func": loss_func,
    # Directories numbered 1..train_phases, followed by a trailing entry
    # numbered 0 (reachable as save_path[-1]).
    "save_path": [
        os.path.join(
            config.RESULTS_PATH,
            "unet_jitter_"
            "train_phase_{}".format(phase_no),
        )
        for phase_no in list(range(1, train_phases + 1)) + [0]
    ],
    "save_epochs": 1,
    "optimizer": torch.optim.Adam,
    "optimizer_params": [
        {"lr": 8e-5, "eps": 1e-5, "weight_decay": 5e-3},
        {"lr": 5e-5, "eps": 1e-5, "weight_decay": 5e-3},
    ],
    "scheduler": torch.optim.lr_scheduler.StepLR,
    # gamma = 1.0 leaves the learning rate unchanged at every scheduler step.
    "scheduler_params": {"step_size": 1, "gamma": 1.0},
    "acc_steps": [1, 200],
    "train_transform": Jitter(2e0, 0.0, 1.0),
}
# -----data prep -----
# Each tensor returned by load_dataset gets a singleton dimension inserted at
# position -2 and is moved to the training device.
X_train, C_train, Y_train = (
    tensor.unsqueeze(-2).to(device)
    for tensor in load_dataset(config.set_params["path"], subset="train")
)
X_val, C_val, Y_val = (
    tensor.unsqueeze(-2).to(device)
    for tensor in load_dataset(config.set_params["path"], subset="val")
)
# ------ save hyperparameters -------
# Dump every setting as a "key: value" line into hyperparameters.txt inside
# the last save_path directory, creating that directory if necessary.
os.makedirs(train_params["save_path"][-1], exist_ok=True)
with open(
    os.path.join(train_params["save_path"][-1], "hyperparameters.txt"), "w"
) as handle:
    for settings in (subnet_params, it_net_params, train_params):
        handle.writelines(
            ": ".join((name, str(setting))) + "\n"
            for name, setting in settings.items()
        )
    handle.write(": ".join(("train_phases", str(train_phases))) + "\n")
# ------ construct network and train -----
# Instantiate the sub-network, wrap it in the iterative network and move both
# to the training device.
subnet = subnet(**subnet_params).to(device)
it_net = IterativeNet(subnet, **it_net_params).to(device)

for phase in range(train_phases):
    # List- or tuple-valued settings contribute their entry for this phase;
    # all other settings are passed through unchanged.
    train_params_cur = {
        name: (setting[phase] if isinstance(setting, (tuple, list)) else setting)
        for name, setting in train_params.items()
    }

    print("Phase {}:".format(phase + 1))
    for name, setting in train_params_cur.items():
        print(name + ": " + str(setting))

    it_net.train_on((Y_train, X_train), (Y_val, X_val), **train_params_cur)