Skip to content

Commit 8076bcc

Browse files
committed
plot using Fourier series on a fine grid, instead of just showing point values
1 parent bdabe3f commit 8076bcc

File tree

2 files changed

+53
-13
lines changed

2 files changed

+53
-13
lines changed

PSPython_01-linear-PDEs.ipynb

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@
191191
"outputs": [],
192192
"source": [
193193
"xi = np.fft.fftfreq(m)*m*2*np.pi/L # Wavenumber \"grid\"\n",
194+
"xi\n",
194195
"# (this is the order in which numpy's FFT gives the frequencies)"
195196
]
196197
},
@@ -213,7 +214,31 @@
213214
"source": [
214215
"# Initial data\n",
215216
"u = np.sin(2*x)**2 * (x<-L/4)\n",
216-
"uhat0 = np.fft.fft(u)"
217+
"uhat0 = np.fft.fft(u)\n",
218+
"plt.plot(x,u)"
219+
]
220+
},
221+
{
222+
"cell_type": "markdown",
223+
"metadata": {},
224+
"source": [
225+
"In the plot above, we have simply \"connected the dots\", using the values of the function at the grid points. We can obtain a more accurate representation by employing the underlying Fourier series representation of the solution, evaluated on a finer grid:"
226+
]
227+
},
228+
{
229+
"cell_type": "code",
230+
"execution_count": null,
231+
"metadata": {},
232+
"outputs": [],
233+
"source": [
234+
"def spectral_representation(x0,uhat):\n",
235+
" u_fun = lambda y : np.real(np.sum(uhat*np.exp(1j*xi*(y+x0))))/len(uhat)\n",
236+
" u_fun = np.vectorize(u_fun)\n",
237+
" return u_fun\n",
238+
"\n",
239+
"u_spectral = spectral_representation(x[0],uhat0)\n",
240+
"x_fine = np.linspace(x[0],x[-1],1000)\n",
241+
"plt.plot(x_fine,u_spectral(x_fine));"
217242
]
218243
},
219244
{
@@ -263,8 +288,6 @@
263288
"plt.close()\n",
264289
"\n",
265290
"def plot_frame(i):\n",
266-
" #fig = plt.figure()\n",
267-
" #plt.plot(x,frames[i])\n",
268291
" line.set_data(x,frames[i])\n",
269292
" axes.set_title('t='+str(i*k))\n",
270293
" fig.canvas.draw()\n",

PSPython_02-pseudospectral-collocation.ipynb

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,11 @@
162162
"metadata": {},
163163
"outputs": [],
164164
"source": [
165+
"def spectral_representation(x0,uhat):\n",
166+
" u_fun = lambda y : np.real(np.sum(uhat*np.exp(1j*xi*(y+x0))))/len(uhat)\n",
167+
" u_fun = np.vectorize(u_fun)\n",
168+
" return u_fun\n",
169+
"\n",
165170
"# Spatial grid\n",
166171
"m=64 # Number of grid points in space\n",
167172
"L = 2 * np.pi # Width of spatial domain\n",
@@ -185,22 +190,29 @@
185190
"\n",
186191
"# Store solutions in a list for plotting later\n",
187192
"frames = [u.copy()]\n",
193+
"ftframes = [uhat0.copy()]\n",
188194
"\n",
189195
"# Now we solve the problem\n",
190196
"for n in range(1,N+1):\n",
191197
" t = n*k\n",
192198
" uhat = np.exp(-(1.j*xi*a + epsilon*xi**2)*t) * uhat0\n",
193199
" u = np.real(np.fft.ifft(uhat))\n",
194200
" frames.append(u.copy())\n",
201+
" ftframes.append(uhat.copy())\n",
195202
" \n",
196203
"# Set up plotting\n",
197204
"fig = plt.figure(figsize=(9,4)); axes = fig.add_subplot(111)\n",
198205
"line, = axes.plot([],[],lw=3)\n",
199206
"axes.set_xlim((x[0],x[-1])); axes.set_ylim((0.,1.))\n",
200207
"plt.close()\n",
201208
"\n",
209+
"x_fine = np.linspace(x[0],x[-1],1000)\n",
210+
"\n",
202211
"def plot_frame(i):\n",
203-
" line.set_data(x,frames[i])\n",
212+
" uhat = ftframes[i]\n",
213+
" u_spectral = spectral_representation(x[0],uhat)\n",
214+
" line.set_data(x_fine,u_spectral(x_fine));\n",
215+
" #line.set_data(x,frames[i])\n",
204216
" axes.set_title('t='+str(i*k))\n",
205217
"\n",
206218
"# Animate the solution\n",
@@ -310,6 +322,7 @@
310322
" A = np.diag(a(x))\n",
311323
" M = -np.dot(A,np.dot(Finv,np.dot(D,F)))\n",
312324
" lamda = np.linalg.eigvals(M)\n",
325+
" print(np.max(np.abs(lamda)))\n",
313326
" fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8,4),\n",
314327
" gridspec_kw={'width_ratios': [3, 1]})\n",
315328
" ax1.plot(x,a(x)); ax1.set_xlim(x[0],x[-1])\n",
@@ -338,7 +351,7 @@
338351
"outputs": [],
339352
"source": [
340353
"a = lambda x : 2 + np.sin(x)\n",
341-
"plot_spectrum(a)"
354+
"plot_spectrum(a,m=64)"
342355
]
343356
},
344357
{
@@ -384,12 +397,14 @@
384397
"D = np.diag(1.j*xi)\n",
385398
"x = np.arange(-m/2,m/2)*(L/m)\n",
386399
"A = np.diag(a(x))\n",
387-
"M = -np.dot(A,np.dot(Finv,np.dot(D,F)))\n",
400+
"#M = -np.dot(A,np.dot(Finv,np.dot(D,F)))\n",
401+
"M = -A@Finv@D@F\n",
388402
"\n",
389403
"# Initial data\n",
390404
"u = np.sin(2*x)**2 * (x<-L/4)\n",
391405
"dx = x[1]-x[0]\n",
392-
"dt = 2.0/m/np.max(np.abs(a(x)))/2.\n",
406+
"dt = 2.0/m/np.max(np.abs(a(x)))\n",
407+
"#dt = 1./86.73416328005729 + 1e-4\n",
393408
"T = 10.\n",
394409
"N = int(np.round(T/dt))\n",
395410
"\n",
@@ -404,7 +419,7 @@
404419
" t = n*dt\n",
405420
" u_old = u.copy()\n",
406421
" u = u_new.copy()\n",
407-
" u_new = u_old + 2*dt*np.dot(M,u)\n",
422+
" u_new = np.real(u_old + 2*dt*np.dot(M,u))\n",
408423
" if ((n % skip) == 0):\n",
409424
" frames.append(u_new.copy())\n",
410425
" \n",
@@ -493,7 +508,7 @@
493508
"T = 5.\n",
494509
"N = int(np.round(T/dt))\n",
495510
"\n",
496-
"frames = [u.copy()]\n",
511+
"ftframes = [np.fft.fft(u)]\n",
497512
"skip = N//100\n",
498513
"\n",
499514
"# Start with an explicit Euler step\n",
@@ -509,9 +524,9 @@
509524
" \n",
510525
" A = np.diag(u)\n",
511526
" M = -np.dot(A,np.dot(Finv,np.dot(D,F)))\n",
512-
" u_new = u_old + 2*dt*np.dot(M,u)\n",
527+
" u_new = np.real(u_old + 2*dt*np.dot(M,u))\n",
513528
" if ((n % skip) == 0):\n",
514-
" frames.append(u_new.copy())\n",
529+
" ftframes.append(np.fft.fft(u_new))\n",
515530
" \n",
516531
"# Set up plotting\n",
517532
"fig, ax1 = plt.subplots(1, 1, figsize=(8,4))\n",
@@ -521,12 +536,14 @@
521536
"plt.close()\n",
522537
"\n",
523538
"def plot_frame(i):\n",
524-
" line1.set_data(x,frames[i])\n",
539+
" uhat = ftframes[i]\n",
540+
" u_spectral = spectral_representation(x[0],uhat)\n",
541+
" line1.set_data(x_fine,u_spectral(x_fine));\n",
525542
" ax1.set_title('t='+str(i*skip*dt))\n",
526543
"\n",
527544
"# Animate the solution\n",
528545
"anim = matplotlib.animation.FuncAnimation(fig, plot_frame,\n",
529-
" frames=len(frames),\n",
546+
" frames=len(ftframes),\n",
530547
" interval=200)\n",
531548
"\n",
532549
"HTML(anim.to_jshtml())"

0 commit comments

Comments
 (0)