Skip to content

Commit

Permalink
plot using Fourier series on a fine grid, instead of just showing poi…
Browse files Browse the repository at this point in the history
…nt values
  • Loading branch information
ketch committed Mar 8, 2023
1 parent bdabe3f commit 8076bcc
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 13 deletions.
29 changes: 26 additions & 3 deletions PSPython_01-linear-PDEs.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@
"outputs": [],
"source": [
"xi = np.fft.fftfreq(m)*m*2*np.pi/L # Wavenumber \"grid\"\n",
"xi\n",
"# (this is the order in which numpy's FFT gives the frequencies)"
]
},
Expand All @@ -213,7 +214,31 @@
"source": [
"# Initial data\n",
"u = np.sin(2*x)**2 * (x<-L/4)\n",
"uhat0 = np.fft.fft(u)"
"uhat0 = np.fft.fft(u)\n",
"plt.plot(x,u)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the plot above, we have simply \"connected the dots\", using the values of the function at the grid points. We can obtain a more accurate representation by employing the underlying Fourier series representation of the solution, evaluated on a finer grid:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def spectral_representation(x0,uhat):\n",
" u_fun = lambda y : np.real(np.sum(uhat*np.exp(1j*xi*(y+x0))))/len(uhat)\n",
" u_fun = np.vectorize(u_fun)\n",
" return u_fun\n",
"\n",
"u_spectral = spectral_representation(x[0],uhat0)\n",
"x_fine = np.linspace(x[0],x[-1],1000)\n",
"plt.plot(x_fine,u_spectral(x_fine));"
]
},
{
Expand Down Expand Up @@ -263,8 +288,6 @@
"plt.close()\n",
"\n",
"def plot_frame(i):\n",
" #fig = plt.figure()\n",
" #plt.plot(x,frames[i])\n",
" line.set_data(x,frames[i])\n",
" axes.set_title('t='+str(i*k))\n",
" fig.canvas.draw()\n",
Expand Down
37 changes: 27 additions & 10 deletions PSPython_02-pseudospectral-collocation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,11 @@
"metadata": {},
"outputs": [],
"source": [
"def spectral_representation(x0,uhat):\n",
" u_fun = lambda y : np.real(np.sum(uhat*np.exp(1j*xi*(y+x0))))/len(uhat)\n",
" u_fun = np.vectorize(u_fun)\n",
" return u_fun\n",
"\n",
"# Spatial grid\n",
"m=64 # Number of grid points in space\n",
"L = 2 * np.pi # Width of spatial domain\n",
Expand All @@ -185,22 +190,29 @@
"\n",
"# Store solutions in a list for plotting later\n",
"frames = [u.copy()]\n",
"ftframes = [uhat0.copy()]\n",
"\n",
"# Now we solve the problem\n",
"for n in range(1,N+1):\n",
" t = n*k\n",
" uhat = np.exp(-(1.j*xi*a + epsilon*xi**2)*t) * uhat0\n",
" u = np.real(np.fft.ifft(uhat))\n",
" frames.append(u.copy())\n",
" ftframes.append(uhat.copy())\n",
" \n",
"# Set up plotting\n",
"fig = plt.figure(figsize=(9,4)); axes = fig.add_subplot(111)\n",
"line, = axes.plot([],[],lw=3)\n",
"axes.set_xlim((x[0],x[-1])); axes.set_ylim((0.,1.))\n",
"plt.close()\n",
"\n",
"x_fine = np.linspace(x[0],x[-1],1000)\n",
"\n",
"def plot_frame(i):\n",
" line.set_data(x,frames[i])\n",
" uhat = ftframes[i]\n",
" u_spectral = spectral_representation(x[0],uhat)\n",
" line.set_data(x_fine,u_spectral(x_fine));\n",
" #line.set_data(x,frames[i])\n",
" axes.set_title('t='+str(i*k))\n",
"\n",
"# Animate the solution\n",
Expand Down Expand Up @@ -310,6 +322,7 @@
" A = np.diag(a(x))\n",
" M = -np.dot(A,np.dot(Finv,np.dot(D,F)))\n",
" lamda = np.linalg.eigvals(M)\n",
" print(np.max(np.abs(lamda)))\n",
" fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8,4),\n",
" gridspec_kw={'width_ratios': [3, 1]})\n",
" ax1.plot(x,a(x)); ax1.set_xlim(x[0],x[-1])\n",
Expand Down Expand Up @@ -338,7 +351,7 @@
"outputs": [],
"source": [
"a = lambda x : 2 + np.sin(x)\n",
"plot_spectrum(a)"
"plot_spectrum(a,m=64)"
]
},
{
Expand Down Expand Up @@ -384,12 +397,14 @@
"D = np.diag(1.j*xi)\n",
"x = np.arange(-m/2,m/2)*(L/m)\n",
"A = np.diag(a(x))\n",
"M = -np.dot(A,np.dot(Finv,np.dot(D,F)))\n",
"#M = -np.dot(A,np.dot(Finv,np.dot(D,F)))\n",
"M = -A@Finv@D@F\n",
"\n",
"# Initial data\n",
"u = np.sin(2*x)**2 * (x<-L/4)\n",
"dx = x[1]-x[0]\n",
"dt = 2.0/m/np.max(np.abs(a(x)))/2.\n",
"dt = 2.0/m/np.max(np.abs(a(x)))\n",
"#dt = 1./86.73416328005729 + 1e-4\n",
"T = 10.\n",
"N = int(np.round(T/dt))\n",
"\n",
Expand All @@ -404,7 +419,7 @@
" t = n*dt\n",
" u_old = u.copy()\n",
" u = u_new.copy()\n",
" u_new = u_old + 2*dt*np.dot(M,u)\n",
" u_new = np.real(u_old + 2*dt*np.dot(M,u))\n",
" if ((n % skip) == 0):\n",
" frames.append(u_new.copy())\n",
" \n",
Expand Down Expand Up @@ -493,7 +508,7 @@
"T = 5.\n",
"N = int(np.round(T/dt))\n",
"\n",
"frames = [u.copy()]\n",
"ftframes = [np.fft.fft(u)]\n",
"skip = N//100\n",
"\n",
"# Start with an explicit Euler step\n",
Expand All @@ -509,9 +524,9 @@
" \n",
" A = np.diag(u)\n",
" M = -np.dot(A,np.dot(Finv,np.dot(D,F)))\n",
" u_new = u_old + 2*dt*np.dot(M,u)\n",
" u_new = np.real(u_old + 2*dt*np.dot(M,u))\n",
" if ((n % skip) == 0):\n",
" frames.append(u_new.copy())\n",
" ftframes.append(np.fft.fft(u_new))\n",
" \n",
"# Set up plotting\n",
"fig, ax1 = plt.subplots(1, 1, figsize=(8,4))\n",
Expand All @@ -521,12 +536,14 @@
"plt.close()\n",
"\n",
"def plot_frame(i):\n",
" line1.set_data(x,frames[i])\n",
" uhat = ftframes[i]\n",
" u_spectral = spectral_representation(x[0],uhat)\n",
" line1.set_data(x_fine,u_spectral(x_fine));\n",
" ax1.set_title('t='+str(i*skip*dt))\n",
"\n",
"# Animate the solution\n",
"anim = matplotlib.animation.FuncAnimation(fig, plot_frame,\n",
" frames=len(frames),\n",
" frames=len(ftframes),\n",
" interval=200)\n",
"\n",
"HTML(anim.to_jshtml())"
Expand Down

0 comments on commit 8076bcc

Please sign in to comment.