diff --git a/examples/wave_equation/readme.md b/examples/wave_equation/readme.md new file mode 100644 index 0000000..20ec083 --- /dev/null +++ b/examples/wave_equation/readme.md @@ -0,0 +1,16 @@ +# Forward Wave Equation + +Solve the 1D wave equation using IDRLnet: + +$$u_{tt} = c^2 u_{xx}$$ + +on $[0, 1] \times [0, 1]$ with fixed-end boundary conditions, sinusoidal initial +displacement, and zero initial velocity. + +**Exact solution:** $u(x, t) = \sin(\pi x) \cos(\pi c t)$ + +## Usage + +```bash +python wave_equation.py +``` diff --git a/examples/wave_equation/wave_equation.py b/examples/wave_equation/wave_equation.py new file mode 100644 index 0000000..aa0fbf7 --- /dev/null +++ b/examples/wave_equation/wave_equation.py @@ -0,0 +1,133 @@ +"""Forward 1D wave equation solved with IDRLnet. + +PDE: u_tt = c^2 * u_xx on [0, 1] x [0, 1] +BCs: u(0, t) = u(1, t) = 0 (fixed ends) +ICs: u(x, 0) = sin(pi * x), u_t(x, 0) = 0 +Exact solution: u(x, t) = sin(pi * x) * cos(pi * c * t) +""" + +from sympy import Symbol, sin, cos, pi +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.tri as tri +import idrlnet.shortcut as sc + +c = 1.0 # wave speed + +x = Symbol("x") +t_symbol = Symbol("t") +time_range = {t_symbol: (0, 1)} +geo = sc.Line1D(0.0, 1.0) + + +@sc.datanode(name="wave_equation") +def interior_domain(): + points = geo.sample_interior( + 10000, bounds={x: (0.0, 1.0)}, param_ranges=time_range + ) + constraints = {"wave_equation": 0} + return points, constraints + + +@sc.datanode(name="t_init") +def init_domain(): + """Initial displacement: u(x, 0) = sin(pi * x).""" + points = geo.sample_interior(200, param_ranges={t_symbol: 0.0}) + constraints = sc.Variables({"u": sin(pi * x)}) + return points, constraints + + +@sc.datanode(name="t_init_vel") +def init_velocity_domain(): + """Initial velocity: u_t(x, 0) = 0.""" + points = geo.sample_interior(200, param_ranges={t_symbol: 0.0}) + constraints = sc.Variables({"u__t": 0.0}) + return points, constraints + + +@sc.datanode(name="x_boundary") +def boundary_domain(): + """Boundary conditions: u(0, t) = u(1, t) = 0.""" + points = geo.sample_boundary(200, param_ranges=time_range) + constraints = sc.Variables({"u": 0.0}) + return points, constraints + + +net = sc.get_net_node( + inputs=("x", "t"), + outputs=("u",), + name="net1", + arch=sc.Arch.mlp, +) +pde = sc.WaveNode(c=c, dim=1, time=True, u="u") +s = sc.Solver( + sample_domains=( + interior_domain(), + init_domain(), + init_velocity_domain(), + boundary_domain(), + ), + netnodes=[net], + pdes=[pde], + max_iter=5000, +) +s.solve() + +# Inference and plotting +coord = s.infer_step( + { + "wave_equation": ["x", "t", "u"], + "t_init": ["x", "t"], + "x_boundary": ["x", "t"], + } +) +num_x = coord["wave_equation"]["x"].cpu().detach().numpy().ravel() +num_t = coord["wave_equation"]["t"].cpu().detach().numpy().ravel() +num_u = coord["wave_equation"]["u"].cpu().detach().numpy().ravel() + +init_x = coord["t_init"]["x"].cpu().detach().numpy().ravel() +init_t = coord["t_init"]["t"].cpu().detach().numpy().ravel() +boundary_x = coord["x_boundary"]["x"].cpu().detach().numpy().ravel() +boundary_t = coord["x_boundary"]["t"].cpu().detach().numpy().ravel() + +# Exact solution for comparison +u_exact = np.sin(np.pi * num_x) * np.cos(np.pi * c * num_t) + +triang_total = tri.Triangulation(num_t, num_x) + +fig, axes = plt.subplots(1, 3, figsize=(18, 5)) + +# Predicted solution +ax1 = axes[0] +tcf = ax1.tricontourf(triang_total, num_u, 100, cmap="jet") +plt.colorbar(tcf, ax=ax1) +ax1.set_xlabel("$t$") +ax1.set_ylabel("$x$") +ax1.set_title("Predicted $u(x,t)$") +ax1.scatter(init_t, init_x, c="black", marker="x", s=8) +ax1.scatter(boundary_t, boundary_x, c="black", marker="x", s=8) + +# Exact solution +ax2 = axes[1] +tcf2 = ax2.tricontourf(triang_total, u_exact, 100, cmap="jet") +plt.colorbar(tcf2, ax=ax2) +ax2.set_xlabel("$t$") +ax2.set_ylabel("$x$") +ax2.set_title("Exact $u(x,t)$") + +# Absolute error +ax3 = axes[2] +tcf3 = ax3.tricontourf(triang_total, np.abs(num_u - u_exact), 100, cmap="jet") +plt.colorbar(tcf3, ax=ax3) +ax3.set_xlabel("$t$") +ax3.set_ylabel("$x$") +ax3.set_title("Absolute error") + +plt.tight_layout() +plt.savefig("wave_equation.png", dpi=500, bbox_inches="tight", pad_inches=0.02) +plt.show() +plt.close() + +# Print L2 error +l2_error = np.sqrt(np.mean((num_u - u_exact) ** 2)) +print(f"L2 error: {l2_error:.6f}")