# ppdd.py
import sys
import os
import numpy as np
from scipy import signal
import cunwrap
import abel
from matplotlib import pyplot as plt, patches
from PIL import Image
import fittools
class PPDD(object):
    """
    Python Plasma Density Diagnostics (PPDD): read an input image, extract its
    phase map, perform the inverse Abel transform and plot the results.

    A run should consist of the following method calls, in this order:
        readfile
        find_peaks
        filt_move
        find_symmetry_axis
        abel
    """

    def __init__(self, xmin=0, xmax=800, ymin=0, ymax=600, xband=0.01, yband=0.1,
                 symin=0, symax=0, crop=50, method='hansenlaw',
                 wavelength=800, scale=1, n0=1, gfactor=1, peak_threshold=99.99,
                 **kwargs):
        """
        Store the analysis parameters.

        xmin, xmax, ymin, ymax: the region [ymin:ymax, xmin:xmax] cropped from the raw input data.
        xband, yband: half passbands of the filter in x and y direction.
        symin, symax: limits the symmetry axis finding to [symin, symax].
        crop: number of columns removed from the left and right edge of the phase plot to avoid edge effects.
        method: key into self.abel_methods ('hansenlaw', 'onion_bordas' or 'basex').
        wavelength: in nm.
        scale: in um per pixel.
        n0: offset added to the transformed image in abel().
        gfactor: multiplicative factor applied to the transformed image in abel().
        peak_threshold: percentile used as default colour limit in plot_amplitude().
        **kwargs: forwarded unchanged to fittools.Guess to build the initial guess.
        """
        self.guess = fittools.Guess(**kwargs)
        self.xmin = xmin
        self.xmax = xmax
        self.ymin = ymin
        self.ymax = ymax
        self.xband = xband
        self.yband = yband
        self.symin = symin
        self.symax = symax
        self.crop = crop
        self.method = method
        self.wavelength = wavelength  # nm
        self.scale = scale  # um per pixel
        self.n0 = n0
        self.gfactor = gfactor
        self.peak_threshold = peak_threshold
        self.learning = True  # try a hot start (reuse self.guess) before a cold start
        self.manual = False  # when True, find_peaks() does not fit anything
        self.peak_fitted = False
        # Dispatch table used by abel(); each entry stores its result in self.AIM.
        self.abel_methods = {
            "hansenlaw": self.abel_hansenlaw,
            "onion_bordas": self.abel_onion_bordas,
            "basex": self.abel_basex,
        }

    def reset(self):
        """
        Mark the peaks as not fitted so that the next find_peaks() call redoes the work.
        """
        self.peak_fitted = False

    def readfile(self, filename):
        """
        Read input file filename into self.rawdata.

        Files with a '.txt' extension are parsed with np.loadtxt as integers,
        everything else is opened as an image.
        """
        _, ext = os.path.splitext(filename)
        # os.path.splitext keeps the leading dot, so the comparison must include it.
        if ext.lower() == '.txt':
            self.rawdata = np.loadtxt(filename, dtype=int)
        else:
            # Context manager makes sure the file handle is released once the pixels are copied.
            with Image.open(filename) as img:
                self.rawdata = np.array(img)
        self.peak_fitted = False

    def find_peaks(self):
        """
        Find the three peaks in the frequency spectrum. This procedure includes cropping the appropriate region.

        Raises RuntimeError if the cold start fit fails.
        """
        if self.peak_fitted:
            # already fitted peaks, skip to speed up
            return
        # loaded new file or region has changed
        self.xy2d = self.rawdata[self.ymin:self.ymax, self.xmin:self.xmax]
        # create the shifted amplitude spectrum to fit
        XYf2d = np.fft.fftn(self.xy2d)
        self.XYf2d_shifted = np.abs(np.fft.fftshift(XYf2d))  # shift frequency of (0,0) to the center
        # if manual, skip the fit
        if self.manual:
            return
        # if learning, try hot start
        if self.learning:
            try:
                self.find_peaks_hot_start()
                return
            except RuntimeError:
                # hot start failed, fall back to cold start
                pass
        self.find_peaks_cold_start()  # if this fails, a RuntimeError will be raised

    def find_peaks_hot_start(self):
        """
        Find the three peaks in the frequency spectrum, starting from self.guess.

        Stores the fitted frequencies in self.fx and self.fy and replaces
        self.guess with the guess returned by fittools.find_peaks.
        """
        # NOTE: a variant that clips the spectrum at the peak_threshold percentile
        # before fitting was tried previously and is intentionally not used here.
        self.fx, self.fy, newguess = fittools.find_peaks(self.XYf2d_shifted, self.guess)
        self.guess = newguess
        self.peak_fitted = True

    def find_peaks_cold_start(self):
        """
        A cold start without relying on the provided initial guess to fit the three peaks. This might be slow.

        Raises RuntimeError if the final hot start fit fails.
        """
        length_x = self.XYf2d_shifted.shape[1]
        dXf = 1 / length_x
        y_x0 = self.XYf2d_shifted[:, length_x // 2]  # the x center line, gg if the main peak isn't there!
        # fit the 1D center line to get a good starting point for sigma
        popt_y = fittools.fit_gaussian(y_x0, -np.inf, np.inf)
        sigma_y0 = popt_y[2]
        x_sum = np.sum(self.XYf2d_shifted, axis=0)  # the projection on x axis
        # TODO further polish wavelet coefficients
        peaks = np.array(signal.find_peaks_cwt(x_sum, np.arange(0.01 * length_x, 0.1 * length_x)))
        acceptance = 0.1  # the left and right peaks should be symmetric, this is the maximum accepted difference ratio
        center_index = fittools.find_nearest(peaks, length_x // 2)
        left_peak = center_index - 1
        right_peak = center_index + 1
        # Stays None when no symmetric pair of side peaks could be fitted, so the
        # fallback below is taken instead of hitting an unbound local.
        popt_x = None
        while True:
            if right_peak >= peaks.shape[0] or left_peak < 0:
                break  # no more peaks in the left or right
            left_dist = peaks[center_index] - peaks[left_peak]
            right_dist = peaks[right_peak] - peaks[center_index]
            if left_dist > right_dist * (1 + acceptance):
                right_peak += 1
                continue
            if right_dist > left_dist * (1 + acceptance):
                left_peak -= 1
                continue
            # peaks are in acceptance
            dist = (left_dist + right_dist) / 2
            mean_peak = (x_sum[peaks[left_peak]] + x_sum[peaks[right_peak]]) / 2
            try:
                popt_x = fittools.find_peaks_1d(x_sum, x_sum[peaks[center_index]], peaks[center_index], sigma_y0,
                                                mean_peak, peaks[center_index] + dist, sigma_y0, 0)
                break  # fit successful
            except RuntimeError:
                # 1D fit failed, change to other peaks and try again
                left_peak -= 1
                right_peak += 1
                continue
        if popt_x is None or not np.any(popt_x):
            # tough luck, find_peaks_cwt doesn't return a pair of left and right peaks in acceptance
            # desperate try
            self.guess.sigma_x0 = sigma_y0
            self.guess.sigma_y0 = sigma_y0
            self.guess.sigma_x1 = sigma_y0
            self.guess.sigma_y1 = sigma_y0
            self.find_peaks_hot_start()  # if fail, a RuntimeError will be raised
            return
        # We have popt_x and popt_y now
        # popt_x: a0, x0, sigma_x0, a1, x1, sigma_x1, offset
        coldguess = fittools.Guess(popt_x[3] / popt_x[0], popt_x[2], sigma_y0, popt_x[5], sigma_y0, 0,
                                   fx=(popt_x[4] - popt_x[1]) * dXf, fy=0)
        self.guess = coldguess
        self.find_peaks_hot_start()  # this shouldn't fail. Were it to fail, a RuntimeError will be raised

    def filt_move(self):
        """
        Filter the image (2d-array) xy2d and move the second peak to center. fx and fy are generated by find_peaks. xband and yband are the passband halfwidths of the filter in x and y direction respectively. The unwrapped phase is stored in self.phase.
        """
        length_x = self.xy2d.shape[1]
        length_y = self.xy2d.shape[0]
        x_filter_length = (length_x // 3 + 1) // 2 * 2 - 1  # must be odd
        y_filter_length = (length_y // 3 + 1) // 2 * 2 - 1  # must be odd
        # Filter on x direction (band-pass around fx)
        b = signal.firwin(x_filter_length, cutoff=[self.fx * 2 - self.xband, self.fx * 2 + self.xband],
                          window=('kaiser', 8), pass_zero=False)
        a = np.zeros([x_filter_length])
        a[0] = 1
        xy2df = signal.filtfilt(b, a, self.xy2d)
        # Filter on y direction (low-pass). abs() keeps the cutoff positive when the
        # fitted fy is negative, matching the passband drawn by plot_amplitude().
        b = signal.firwin(y_filter_length, cutoff=abs(self.fy) + self.yband, window=('kaiser', 8), pass_zero=True)
        a = np.zeros([y_filter_length])
        a[0] = 1
        xy2df = signal.filtfilt(b, a, xy2df, axis=0)
        # Remove negative frequencies
        XYf2df = np.fft.fftn(xy2df)
        XYf2df[:, length_x // 2:] = 0
        # Shift second peak to center
        xy2df0 = np.fft.ifftn(XYf2df)
        phase = np.angle(xy2df0)
        shifter_x = np.arange(length_x)
        phase += shifter_x * (-2 * np.pi * self.fx)
        shifter_y = np.arange(length_y)[:, np.newaxis]
        phase += shifter_y * (-2 * np.pi * self.fy)
        # Unwrap
        phase = (phase + np.pi) % (2 * np.pi) - np.pi
        unwrapped = cunwrap.unwrap(phase)
        # An explicit end index keeps every column when crop == 0
        # (a plain [crop:-crop] slice would be empty in that case).
        self.phase = unwrapped[:, self.crop:unwrapped.shape[1] - self.crop]

    def find_symmetry_axis(self):
        """
        Determine the symmetry axis of the phase and store it in self.ycenter.

        symin == symax == 0: search the whole cropped height.
        symin == symax != 0: use that value directly, no search.
        otherwise: search within [symin, symax].
        """
        if self.symin == self.symax:
            if self.symin == 0:
                self.ycenter = fittools.find_symmetry_axis(self.phase, 0, self.ymax - self.ymin)
            else:
                self.ycenter = self.symin
            return
        self.ycenter = fittools.find_symmetry_axis(self.phase, self.symin, self.symax)

    def abel(self):
        """
        Perform the inverse Abel transform selected by self.method on the phase
        and convert the result; the final image is stored in self.AIM.
        """
        IM = fittools.half_image(self.phase.transpose(), self.ycenter)
        self.abel_methods[self.method](IM)
        self.AIM = self.AIM * (self.wavelength / 1e3) / (2 * np.pi * self.scale) * self.gfactor
        # delta n should always be positive, so make sure the majority of image is positive
        if (self.AIM > 0).sum() < (self.AIM < 0).sum():
            self.AIM = -self.AIM
        self.AIM += self.n0
        # ndarray.clip returns a new array, so the result has to be assigned back.
        self.AIM = self.AIM.clip(min=1)
        nc = 1.11943771e27 / (self.wavelength ** 2)  # cm^-3
        self.AIM = (1 - 1 / self.AIM ** 2) * nc

    def abel_hansenlaw(self, IM):
        """
        Inverse Abel transform of IM with abel.hansenlaw; result stored transposed in self.AIM.
        """
        self.AIM = abel.hansenlaw.hansenlaw_transform(IM, direction='inverse').transpose()

    def abel_onion_bordas(self, IM):
        """
        Inverse Abel transform of IM with abel.onion_bordas; result stored transposed in self.AIM.
        """
        self.AIM = abel.onion_bordas.onion_bordas_transform(IM, direction='inverse').transpose()

    def abel_basex(self, IM):
        """
        Inverse Abel transform of IM with abel.basex; result stored transposed in self.AIM.

        The 'basex' directory next to the running script is created if needed and
        passed to basex_transform as basis_dir.
        """
        basex_path = os.path.join(os.path.dirname(os.path.realpath(sys.argv[0])), 'basex')
        os.makedirs(basex_path, exist_ok=True)
        self.AIM = abel.basex.basex_transform(IM, basis_dir=basex_path, direction='inverse').transpose()

    def plot_raw(self, ax, region=None):
        """
        Plot the raw data on ax.

        region: tuple of (xmin, xmax, ymin, ymax). A region other than None adds a rectangle displaying the selected region.
        """
        ax.set_title('Raw data')
        ax.pcolormesh(self.rawdata)
        ax.set_xlim(0, self.rawdata.shape[1])
        ax.set_ylim(0, self.rawdata.shape[0])
        if region:
            rect = patches.Rectangle((region[0], region[2]), region[1] - region[0], region[3] - region[2],
                                     linewidth=2, edgecolor='r', facecolor='none')
            ax.add_patch(rect)

    def plot_amplitude(self, ax, vmax=0, bands=None):
        """
        Plot the shifted amplitude spectrum on ax.

        vmax: upper colour limit; 0 means use the peak_threshold percentile of the spectrum.
        bands: tuple of (xband, yband). Bands other than None add a rectangle displaying the passbands.
        """
        ax.set_title('Amplitude spectrum')
        xfreq = np.fft.fftshift(np.fft.fftfreq(self.XYf2d_shifted.shape[1]))
        yfreq = np.fft.fftshift(np.fft.fftfreq(self.XYf2d_shifted.shape[0]))
        if vmax == 0:
            vmax = np.percentile(self.XYf2d_shifted, self.peak_threshold)
        ax.pcolormesh(xfreq, yfreq, self.XYf2d_shifted, vmax=vmax)
        ax.set_xlim(-0.1, 0.1)
        ax.set_ylim(-0.2, 0.2)
        if bands:
            rect = patches.Rectangle((self.fx - bands[0], -(abs(self.fy) + bands[1])), 2 * bands[0],
                                     2 * (bands[1] + abs(self.fy)), linewidth=2, edgecolor='r', facecolor='none')
            ax.add_patch(rect)

    def plot_phase(self, ax, cax, limits=None, symmetry=None):
        """
        Plot the phase on ax with its colorbar on cax.

        limits: tuple of (symin, symax). Limits other than None add two lines marking the limits of the symmetry axis search.
        symmetry: y index of symmetry axis. A symmetry other than None adds the symmetry axis.
        """
        ax.set_title('Phase spectrum')
        im = ax.pcolormesh(self.phase)
        plt.colorbar(im, cax)
        if limits:
            ax.hlines(limits[0], 0, self.phase.shape[1], linewidth=2, colors='r')
            ax.hlines(limits[1], 0, self.phase.shape[1], linewidth=2, colors='r')
        # Compare against None so that a symmetry axis at index 0 is still drawn.
        if symmetry is not None:
            ax.hlines(symmetry, 0, self.phase.shape[1], linewidth=3, colors='black')

    def plot_density(self, ax, cax, vmin=0, vmax=0):
        """
        Plot the result of the Abel transform (self.AIM) on ax with its colorbar on cax.

        vmin, vmax: colour limits; when equal, 0 and the 99th percentile of self.AIM are used.
        """
        ax.set_title('Relative Refractivity')
        if vmin == vmax:
            vmin = 0
            vmax = np.percentile(self.AIM, 99)
        im = ax.pcolormesh(self.AIM, vmin=vmin, vmax=vmax)
        plt.colorbar(im, cax)
        ax.set_xlim(0, self.AIM.shape[1])
        ax.set_ylim(0, self.AIM.shape[0])
#def main():
# #create a PPDD object
# pypdd = PPDD()
# failed_reads = []
# failed_peaks = []
# failed_symmetries = []
#
# #Read Data
# for filename in sys.argv[1:]:
# try :
# pypdd.readfile(filename)
# except :
# failed_reads.append(filename)
# continue
# #Fit three peaks to find the secondary peak
# try :
# pypdd.find_peaks()
# except RuntimeError :
# failed_peaks.append(filename)
# continue
# #Filter
# pypdd.filt_move()
# #Find the center of phase spectrum
# try :
# pypdd.find_symmetry_axis()
# except RuntimeError : #currently not possible because find_symmetry_axis always give a center in [ymin, ymax]
# failed_symmetries.append(filename)
# continue
# #Abel transform
# try :
# pypdd.abel()
# except ValueError : #given invalid symmetry axis
# failed_symmetries.append(filename)
# continue
#
# #Plot
# plt.close("all")
# f, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(20,10))
#
# ax1.set_title('Raw data')
# im1 = ax1.pcolormesh(pypdd.rawdata)
# rect1 = patches.Rectangle((pypdd.xmin, pypdd.ymin), pypdd.xmax-pypdd.xmin, pypdd.ymax-pypdd.ymin, linewidth=2, edgecolor='r', facecolor='none')
# ax1.set_xlim(0, pypdd.rawdata.shape[1])
# ax1.set_ylim(0, 800)
# ax1.add_patch(rect1)
#
# ax2.set_title('Phase spectrum')
# im2 = ax2.pcolormesh(pypdd.phase)
# ax2.hlines(pypdd.ycenter, 0, pypdd.phase.shape[1], linewidth=3, colors='black')
# divider2 = make_axes_locatable(ax2)
# cax2 = divider2.append_axes("right", size="5%", pad=0.05)
# plt.colorbar(im2, cax2)
#
# ax3.set_title('Amplitude spectrum')
# XYf2d_shifted = pypdd.XYf2d_shifted
# im3 = ax3.pcolormesh(np.fft.fftshift(np.fft.fftfreq(XYf2d_shifted.shape[1])), np.fft.fftshift(np.fft.fftfreq(XYf2d_shifted.shape[0])), XYf2d_shifted,vmax=1e6)
# ax3.set_xlim(-0.2,0.2)
# ax3.set_ylim(-0.2,0.2)
# rect3 = patches.Rectangle((pypdd.fx-pypdd.xband,-np.abs(pypdd.fy)-pypdd.yband), 2*pypdd.xband, 2*(pypdd.yband+np.abs(pypdd.fy)), linewidth=2, edgecolor='r', facecolor='none')
# ax3.add_patch(rect3)
#
# ax4.set_title('Relative Refractivity')
# im4 = ax4.pcolormesh(pypdd.AIM, vmax=0.1, vmin=0)
# divider4 = make_axes_locatable(ax4)
# cax4 = divider4.append_axes("right", size="5%", pad=0.05)
# plt.colorbar(im4, cax4)
#
# outputpath = os.path.join(os.path.dirname(os.path.realpath(sys.argv[0])), 'output')
# os.makedirs(outputpath, exist_ok=True)
# plt.savefig(os.path.join(outputpath, os.path.basename(filename).rsplit('.', 1)[0]+'.png'), bbox_inches='tight')
# plt.close()
#
# if failed_reads:
# print("Failed to read these input files:", file=sys.stderr)
# for i in failed_reads:
# print(i, file=sys.stderr)
# if failed_peaks:
# print("Failed to find the secondary peak in these input files:", file=sys.stderr)
# for i in failed_peaks:
# print(i, file=sys.stderr)
# if failed_symmetries:
# print("Failed to find the symmetry axis of phase spectrum in these input files:", file=sys.stderr)
# for i in failed_symmetries:
# print(i, file=sys.stderr)
#
# return 0
#
#if __name__ == "__main__":
# main()