Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions examples/wave_equation/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Forward Wave Equation

Solve the 1D wave equation using IDRLnet:

$$u_{tt} = c^2 u_{xx}$$

on $[0, 1] \times [0, 1]$ with fixed-end boundary conditions, sinusoidal initial
displacement, and zero initial velocity.

**Exact solution:** $u(x, t) = \sin(\pi x) \cos(\pi c t)$

## Usage

```bash
python wave_equation.py
```
133 changes: 133 additions & 0 deletions examples/wave_equation/wave_equation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
"""Forward 1D wave equation solved with IDRLnet.

PDE: u_tt = c^2 * u_xx on [0, 1] x [0, 1]
BCs: u(0, t) = u(1, t) = 0 (fixed ends)
ICs: u(x, 0) = sin(pi * x), u_t(x, 0) = 0
Exact solution: u(x, t) = sin(pi * x) * cos(pi * c * t)
"""

from sympy import Symbol, sin, cos, pi
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cos is imported from sympy but not used anywhere in this script. Consider removing it to avoid unused-import lint failures and keep dependencies minimal.

Suggested change
from sympy import Symbol, sin, cos, pi
from sympy import Symbol, sin, pi

Copilot uses AI. Check for mistakes.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.tri as tri
import idrlnet.shortcut as sc

c = 1.0 # wave speed

x = Symbol("x")
t_symbol = Symbol("t")
time_range = {t_symbol: (0, 1)}
geo = sc.Line1D(0.0, 1.0)


@sc.datanode(name="wave_equation")
def interior_domain():
points = geo.sample_interior(
10000, bounds={x: (0.0, 1.0)}, param_ranges=time_range
)
constraints = {"wave_equation": 0}
return points, constraints


@sc.datanode(name="t_init")
def init_domain():
"""Initial displacement: u(x, 0) = sin(pi * x)."""
points = geo.sample_interior(200, param_ranges={t_symbol: 0.0})
constraints = sc.Variables({"u": sin(pi * x)})
return points, constraints


@sc.datanode(name="t_init_vel")
def init_velocity_domain():
"""Initial velocity: u_t(x, 0) = 0."""
points = geo.sample_interior(200, param_ranges={t_symbol: 0.0})
constraints = sc.Variables({"u__t": 0.0})
return points, constraints


@sc.datanode(name="x_boundary")
def boundary_domain():
"""Boundary conditions: u(0, t) = u(1, t) = 0."""
points = geo.sample_boundary(200, param_ranges=time_range)
constraints = sc.Variables({"u": 0.0})
return points, constraints


net = sc.get_net_node(
inputs=("x", "t"),
outputs=("u",),
name="net1",
arch=sc.Arch.mlp,
)
pde = sc.WaveNode(c=c, dim=1, time=True, u="u")
s = sc.Solver(
sample_domains=(
interior_domain(),
init_domain(),
init_velocity_domain(),
boundary_domain(),
),
netnodes=[net],
pdes=[pde],
max_iter=5000,
)
s.solve()

# Inference and plotting
coord = s.infer_step(
{
"wave_equation": ["x", "t", "u"],
"t_init": ["x", "t"],
"x_boundary": ["x", "t"],
}
)
num_x = coord["wave_equation"]["x"].cpu().detach().numpy().ravel()
num_t = coord["wave_equation"]["t"].cpu().detach().numpy().ravel()
num_u = coord["wave_equation"]["u"].cpu().detach().numpy().ravel()

init_x = coord["t_init"]["x"].cpu().detach().numpy().ravel()
init_t = coord["t_init"]["t"].cpu().detach().numpy().ravel()
boundary_x = coord["x_boundary"]["x"].cpu().detach().numpy().ravel()
boundary_t = coord["x_boundary"]["t"].cpu().detach().numpy().ravel()

# Exact solution for comparison
u_exact = np.sin(np.pi * num_x) * np.cos(np.pi * c * num_t)

triang_total = tri.Triangulation(num_t, num_x)

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Predicted solution
ax1 = axes[0]
tcf = ax1.tricontourf(triang_total, num_u, 100, cmap="jet")
plt.colorbar(tcf, ax=ax1)
ax1.set_xlabel("$t$")
ax1.set_ylabel("$x$")
ax1.set_title("Predicted $u(x,t)$")
ax1.scatter(init_t, init_x, c="black", marker="x", s=8)
ax1.scatter(boundary_t, boundary_x, c="black", marker="x", s=8)

# Exact solution
ax2 = axes[1]
tcf2 = ax2.tricontourf(triang_total, u_exact, 100, cmap="jet")
plt.colorbar(tcf2, ax=ax2)
ax2.set_xlabel("$t$")
ax2.set_ylabel("$x$")
ax2.set_title("Exact $u(x,t)$")

# Absolute error
ax3 = axes[2]
tcf3 = ax3.tricontourf(triang_total, np.abs(num_u - u_exact), 100, cmap="jet")
plt.colorbar(tcf3, ax=ax3)
ax3.set_xlabel("$t$")
ax3.set_ylabel("$x$")
ax3.set_title("Absolute error")

plt.tight_layout()
plt.savefig("wave_equation.png", dpi=500, bbox_inches="tight", pad_inches=0.02)
plt.show()
plt.close()

# Print L2 error
l2_error = np.sqrt(np.mean((num_u - u_exact) ** 2))
print(f"L2 error: {l2_error:.6f}")
Loading