-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_recon.py
153 lines (139 loc) · 5.77 KB
/
run_recon.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import argparse
import sys, pathlib
import pickle
import h5py
import numpy as np
from src.pty_base import ObjectProbeUpdateMode
from src.pty_algs import parse_alg, parse_eps
from src.pty_live import run_rtisi_psi, RealTimeProbeUpdate
from src.pty_data import get_norm_probe
from src.run_utils import get_data, get_git_revision_hash, get_random_runname, save_results
# Numeric precision handed to get_data() as its `ctype` / `ftype` keyword
# arguments: single-precision complex and single-precision real.
ctype = np.complex64
ftype = np.float32

# ---------------------------------------------------------------------------
# Command-line interface.
#
# Three positionals (object/probe index, iterations, buffer size) followed by
# options controlling the algorithm, the simulated data and the initialization.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument(
    'obj_probe_idx', type=int,
    help='Index of object and probe (from our simulated dataset) to use. '
         'Should be between 0 and 29, or between 0 and 89 when using --full.')
parser.add_argument(
    'iters', type=int,
    help='Number of algorithm iterations per scan index')
parser.add_argument(
    'buffersize', type=int,
    help='Number of exit waves in the buffer')
parser.add_argument(
    '--runname', type=str, required=False,
    help='Name of this run. Should be unique, will throw error if results already exist. '
         'get_random_runname() is called if this is not passed.')
# NOTE(review): --alg is neither required nor given a default, so args.alg is
# None when omitted and is passed to parse_alg() as such -- confirm whether
# parse_alg accepts None or whether this option should be required.
parser.add_argument(
    '--alg', type=str,
    help='Algorithm to run for live reconstruction. Check parse_alg function from `algs_psi.py`. '
         'Examples: "dm1,1e-12" or "Fhybrid_8Xdm1,1e-12+2Xer1e-12".')
parser.add_argument(
    '--commit-eps', type=str, required=True,
    help='Epsilon to use for object/probe update after each exit wave commit.')
parser.add_argument(
    '--commit-pq', type=float, required=True,
    help='Probe quantile clip to use for probe update after each exit wave commit. '
         '0.95 is a good default. Pass 1 to turn off.')
# type=int: without it argparse hands user-supplied values over as strings
# (e.g. ['-1', '-2']) while the default is the integer list [-1]. Converting
# here makes both paths produce a list of ints. The list default itself is not
# passed through `type` (argparse only converts string defaults).
parser.add_argument(
    '--phi0-idxs', nargs='*', type=int, default=[-1],
    help='Buffer-relative indices to apply Phi0 estimation to. [-1] by default (like in RTISI/RTISI-LA).')
parser.add_argument(
    '--scandens', type=int, choices=[10, 15, 20, 30], default=10,
    help='Archimedes Spiral scan density (in px). 10 by default.')
parser.add_argument(
    '--Ifac', type=float, default=1,
    help='Global factor to apply to simulated intensities. Useful to set epsilons relative to. 1 by default.')
parser.add_argument(
    '--poisson-lambda-max', type=float, default=1e9,
    help='Expected max. intensity for all patterns for Poisson noise simulation. 1e9 by default.')
parser.add_argument(
    '--use-gt-probe', action='store_true',
    help='Use the ground-truth probe for reconstruction. Overrides --central-from-file if passed.')
# The default is given as a string; argparse runs string defaults through
# `type`, so it ends up converted by RealTimeProbeUpdate.from_string as well.
parser.add_argument(
    '--rt-probe-update', type=RealTimeProbeUpdate.from_string, choices=list(RealTimeProbeUpdate),
    default=str(RealTimeProbeUpdate.EACH_ITERATION),
    help='Mode for real-time probe update. "each_iteration" by default.')
parser.add_argument(
    '--central-from-file', type=str,
    help='Load central reconstruction from HDF5 file.')
parser.add_argument(
    '--kmax', type=int,
    help="Pass to limit number of total exit waves processed, for debugging etc.")
# NOTE(review): args.full is not read anywhere in the visible script (get_data
# is called without it) -- confirm whether it should be forwarded.
parser.add_argument(
    '--full', action='store_true',
    help="Use full dataset (40+40+10 object/probe pairs) instead of 'smol' dataset (10+10+10, default).")
parser.add_argument(
    '--naive-phase-init', type=str, choices=['zero', 'uniform'], default='zero',
    help='Naive phase initialization for exit waves. Recommended default is zero.')
args = parser.parse_args()
# Index of the object/probe pair selected on the command line.
data_idx = args.obj_probe_idx

# Run name: an explicitly passed --runname wins, otherwise draw a random one.
if args.runname is not None:
    runname = args.runname
else:
    runname = get_random_runname(8)

# All outputs of this run go below results/<runname>/.
outpath = pathlib.Path(f'results/{runname}/')
result_file = outpath / f'{data_idx}.h5'
print(f"Using runname: {runname}")

# Never overwrite: bail out with exit status 1 if the result file for this
# index is already present under this run name.
if result_file.exists():
    print(f"runname {runname} already has a written output file for index {data_idx}! Exiting.")
    sys.exit(1)
outpath.mkdir(parents=True, exist_ok=True)

# Record how this run was launched (parsed CLI arguments and git commit) next
# to the results, so the run can be traced back later.
meta_path = outpath / f'{data_idx}_meta.pkl'
with open(meta_path, "wb") as meta_fh:
    meta = {'type': 'rtisi', 'args': args, 'commit': get_git_revision_hash()}
    pickle.dump(meta, meta_fh)
# Fetch the data for the selected object/probe pair.
Ak, rk, Psik0, O_gt, P, stft_gt = get_data(
    data_idx,
    args.scandens,
    lamb=args.poisson_lambda_max,
    Ifac=args.Ifac,
    naive_phase_init=args.naive_phase_init,
    ctype=ctype,
    ftype=ftype,
)

# The buffer size bounds how much information is available in the very first
# step, so it is needed already for the initialization below.
B = args.buffersize

# Choose the initial probe P0, the initial object O0 and the slice of Ak
# (A_alg) that the algorithm/epsilon parsers get to see.
if args.use_gt_probe:
    # Ground-truth probe, no object estimate. Tested first, so this wins over
    # --central-from-file when both are passed.
    P0, O0 = P, None
    A_alg = Ak[:B]
elif args.central_from_file is not None:
    # Warm start from a stored central reconstruction: object, probe and the
    # leading exit waves are read from per-index datasets in the HDF5 file.
    print(f"Loading central reconstruction from HDF5 file {args.central_from_file}")
    with h5py.File(args.central_from_file, 'r') as h5:
        O0 = h5[f'{data_idx}_O_central'][:]
        P0 = h5[f'{data_idx}_P_central'][:]
        Psi_central = h5[f'{data_idx}_Psi_central'][:]
    # Overwrite the first entries of the initial exit waves with the stored
    # ones and restrict A_alg to the same leading entries.
    n_central = Psi_central.shape[0]
    Psik0[:n_central] = Psi_central
    A_alg = Ak[:n_central]
else:
    # Uninformed start: complex Gaussian noise with the shape of P, passed
    # through get_norm_probe together with the first B entries of Ak.
    A_alg = Ak[:B]
    noise_re = np.random.randn(*P.shape)
    noise_im = np.random.randn(*P.shape)
    P0 = get_norm_probe(noise_re + 1j * noise_im, A_alg)
    O0 = None
# Turn the textual CLI settings into what the real-time algorithm consumes.
# Both parsers receive A_alg alongside the raw string.
eps_commit = parse_eps(args.commit_eps, A_alg)
alg = parse_alg(args.alg, A_alg)
print(f"Using algorithm: {alg}")

# Keyword options for the reconstruction, collected in one place.
rtisi_options = dict(
    B=B,
    alg=alg,
    iters=args.iters,
    phi0_est_idxs=args.phi0_idxs,
    naive_phase_init=args.naive_phase_init,
    probe_update=args.rt_probe_update,
    object_probe_update_mode=ObjectProbeUpdateMode.PROBE_THEN_OBJECT,
    pA_before_commit=True,
    track_interm=None,
    eps_commit=eps_commit,
    pq_commit=args.commit_pq,
    Psik0=Psik0,
    O0=O0,
    kmax=args.kmax,
)

# Run the reconstruction. Six values come back; only the first three (final
# object, probe and exit waves) are used here.
Of, Pf, Psif, _, _, _ = run_rtisi_psi(Ak, rk, O_gt.shape, P0, **rtisi_options)

# Persist the reconstruction together with the ground truth and the inputs.
# P returned by get_data serves as the ground-truth probe.
print(f"Saving results to {outpath}")
save_results(outpath, data_idx, Of, Pf, Psif, O_gt, P, Ak, rk, stft_gt)
print("Done!")