Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions examples/jax/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# JAX Examples

Examples demonstrating HydroGym's JAX backend for GPU-accelerated, fully-differentiable flow control.

## What is the JAX backend?

HydroGym's JAX backend provides pseudo-spectral Navier-Stokes solvers written entirely in JAX. This enables:

- **GPU acceleration** — solvers run on GPU via JAX's XLA compilation
- **Vectorized environments** — run many parallel environments inside a single JIT-compiled training loop (PureJAX-style)
- **End-to-end differentiability** — gradients can flow through the solver for gradient-based control

The JAX environments follow the [gymnax](https://github.com/RobertTLange/gymnax) interface (`reset_env` / `step_env` with explicit `params`) and include wrappers (`VecEnv`, `LogWrapper`, `ClipAction`, `NormalizeVecObservation`, `NormalizeVecReward`) for RL training.

## Directory Structure

```
jax/
├── README.md # This file
└── getting_started/ # START HERE
├── README.md # Detailed guide and comparison table
├── 1_kolmogorov/ # 2D Kolmogorov flow (Re=200)
├── 2_channel/ # 3D turbulent channel flow (Re_tau=180)
└── 3_ppo/ # Pure-JAX PPO training (both environments)
```

## Quick Start

```bash
# Activate the GPU environment
source /home/easybuild/venvs/hydrogym_gpu/bin/activate

# Test Kolmogorov flow (float64, 10 steps)
cd getting_started/1_kolmogorov
./run_kolmogorov_docker.sh

# Test channel flow (float32, 5 steps)
cd getting_started/2_channel
./run_channel_docker.sh

# Train PPO
cd getting_started/3_ppo
./run_ppo_docker.sh --env kolmogorov --total-timesteps 20000
```

## Available Environments

| Environment | Solver | Grid | Action | Observation | Reward | Default dtype |
|---|---|---|---|---|---|---|
| `KolmogorovFlow` | 2D pseudo-spectral | 64×64 | 4 body-force modes | 8×8 velocity probes | -(α·TKE + action penalty) | float64 |
| `ChannelFlowSpectralEnv` | 3D pseudo-spectral | 72×72×72 | 24 wall jets | 8×8×2 near-wall velocities | -WSS (drag) | float32 |

## JIT Compilation

Both environments are JIT-compiled via `jax.jit` in the runner scripts, which compiles the full DNS rollout into a single GPU kernel:

```python
jit_reset = jax.jit(env.reset_env)
jit_step = jax.jit(env.step_env)

obs, state = jit_reset(key, params) # triggers compilation
obs, state, reward, done, info = jit_step(key, state, action, params) # full GPU speed
```

The first call compiles (takes ~1–2 minutes); all subsequent calls run at full GPU speed.

## Floating-Point Precision

| Environment | Recommended | Notes |
|---|---|---|
| `KolmogorovFlow` | `float64` | Pseudo-spectral 2D NS requires fp64 for JIT stability; fp32 may produce NaNs under XLA reordering |
| `ChannelFlowSpectralEnv` | `float32` | Stable at fp32 with JIT; fp64 available but ~2x slower on A100 |

Override via `env_config`:
```python
# Kolmogorov: float64 is the default and required for JIT stability
env = KolmogorovFlow(env_config={"dt": 5e-4}) # smaller dt for fp32 experiments

# Channel: toggle precision
env = ChannelFlowSpectralEnv(env_config={"dtype": "float64"})
```

Or via the bash scripts:
```bash
./run_kolmogorov_docker.sh minimize_tke 100 float32 # float32 (may diverge)
./run_channel_docker.sh drag_reduction 10 float64 # float64
```

## Typical Usage

```python
import jax
import jax.numpy as jnp
from hydrogym.jax.envs.kolmogorov import KolmogorovFlow

jax.config.update("jax_enable_x64", True) # required for Kolmogorov + JIT

env = KolmogorovFlow(env_config={}, flow_config={})
params = env.default_params

jit_reset = jax.jit(env.reset_env)
jit_step = jax.jit(env.step_env)

key = jax.random.PRNGKey(0)
obs, state = jit_reset(key, params)

action = jnp.zeros((params.action_dim,))
obs, state, reward, done, info = jit_step(key, state, action, params)
```

**Note:** The channel flow environment downloads a fully turbulent initial field from Hugging Face Hub (`dynamicslab/HydroGym-environments`) on the first run and caches it at `~/.cache/hydrogym/`.

## Requirements

- JAX with GPU support (`jax[cuda12]` or equivalent)
- `flax`, `optax`, `distrax` for PPO training
- Internet access on first run (channel flow initial field download)
338 changes: 338 additions & 0 deletions examples/jax/getting_started/1_kolmogorov/kolmogorov.ipynb

Large diffs are not rendered by default.

54 changes: 54 additions & 0 deletions examples/jax/getting_started/1_kolmogorov/run_kolmogorov_docker.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#!/usr/bin/env bash
#
# Run Kolmogorov flow JAX environment for two control objectives:
#
# Objective 1 -- Minimize TKE (suppress energy bursts)
# reward = -(reward_alpha * TKE + action_penalty), reward_alpha > 0
# The agent is penalized for high turbulent kinetic energy and large
# actions. Setting reward_alpha > 0 drives the flow toward a laminar,
# low-energy state.
#
# Objective 2 -- Maximize TKE (enhance turbulent mixing)
# reward = -(reward_alpha * TKE + action_penalty), reward_alpha < 0
# A negative reward_alpha flips the sign of the TKE term, making the
# reward proportional to TKE. The agent is rewarded for driving the
# flow into a more turbulent regime while being penalized for large
# actions.
#
# In both cases the action penalty term (-sum(|a_i|)) discourages
# unnecessarily large actuations and promotes efficient controllers.
#
# Usage:
# ./run_kolmogorov_docker.sh [mode] [num_steps] [dtype]
#
# ./run_kolmogorov_docker.sh # minimize TKE, 10 steps, float64
# ./run_kolmogorov_docker.sh maximize_tke # maximize TKE
# ./run_kolmogorov_docker.sh no_actuation 500 # baseline, 500 steps
# ./run_kolmogorov_docker.sh minimize_tke 1000 float32 # float32 (fast, may diverge)
#
# Actuation:
# The control input is the amplitude of four sinusoidal body-force modes
# added to the x-momentum equation:
# c(y) = a1*sin(k1*y) + a2*sin(k2*y) + a3*sin(k3*y) + a4*sin(k4*y)
# with wavenumbers k1,k2,k3,k4 = 4,5,6,7 (above the base forcing wavenumber).
# Actions are clipped to [-0.5, 0.5].
#
# Precision:
# float64 (default) -- required for JIT stability; matches non-JIT behavior
# float32 -- faster but may produce NaNs due to solver instability under XLA
#

set -e

module purge
module load Python/3.12.3-GCCcore-13.3.0

# Activate Python environment
source /home/easybuild/venvs/hydrogym_gpu/bin/activate

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
MODE="${1:-minimize_tke}"
NUM_STEPS="${2:-10}"
DTYPE="${3:-float64}"

python "$SCRIPT_DIR/test_kolmogorov_env.py" "$MODE" --num-steps "$NUM_STEPS" --dtype "$DTYPE"
198 changes: 198 additions & 0 deletions examples/jax/getting_started/1_kolmogorov/test_kolmogorov_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
"""
Kolmogorov Flow JAX Environment

Usage:
python test_kolmogorov_env.py [mode] [--num-steps N] [--plot]

Modes:
minimize_tke Suppress energy bursts: reward_alpha=1.0 (default)
maximize_tke Enhance turbulent mixing: reward_alpha=-1.0
no_actuation Baseline: zero action, free turbulence evolution

Options:
--plot Save a vorticity comparison PNG (baseline vs selected mode)
"""

import argparse
import sys

# Must be set before JAX initializes — check sys.argv directly
import jax
import jax.numpy as jnp
import numpy as np

from hydrogym.jax.envs.kolmogorov import KolmogorovFlow

jax.config.update("jax_enable_x64", "float32" not in sys.argv)

# ── Mode definitions ────────────────────────────────────────────────────────

MODE_CONFIGS = {
"minimize_tke": dict(
reward_alpha=1.0,
action=jnp.array([-0.25, -0.03, 0.02, 0.01]),
description=(
"Objective: Minimize TKE (suppress energy bursts)\n"
" reward_alpha = 1.0 -> reward = -(TKE + action_penalty)\n"
" Action: small forcing to damp energy transfer"
),
),
"maximize_tke": dict(
reward_alpha=-1.0,
action=jnp.array([0.25, 0.03, -0.02, -0.01]),
description=(
"Objective: Maximize TKE (enhance turbulent mixing)\n"
" reward_alpha = -1.0 -> reward = TKE - action_penalty\n"
" Action: forcing to drive the flow into a more turbulent regime"
),
),
"no_actuation": dict(
reward_alpha=1.0,
action=None, # filled in after env init (needs action_dim)
description=(
"Baseline: zero actuation (free turbulence evolution)\n"
" reward_alpha = 1.0, action = [0, 0, 0, 0]\n"
" Shows natural energy bursts without control"
),
),
}


# ── Helpers ─────────────────────────────────────────────────────────────────


def to_real(omega_hat):
return np.asarray(jnp.fft.irfftn(omega_hat))


def run_steps(env, params, action, num_steps):
jit_reset = jax.jit(env.reset_env)
jit_step = jax.jit(env.step_env)

key = jax.random.PRNGKey(0)

print("Compiling the environment (this may take a moment)...")
obs, state = jit_reset(key, params)
obs, state, reward, done, info = jit_step(key, state, action, params)
obs.block_until_ready()
print("Compilation finished! Now running at full speed.\n")

key = jax.random.PRNGKey(1)
obs, state = jit_reset(key, params)

rows = []
for i in range(num_steps):
key, subkey = jax.random.split(key)
obs, state, reward, done, info = jit_step(subkey, state, action, params)
rows.append((i, float(info["mean_tke"]), float(reward)))
if done:
break
obs.block_until_ready()
return state, rows


def save_plot(env, params, action_actuated, outfile="kolmogorov_comparison.png"):
import matplotlib.pyplot as plt

zero_action = jnp.zeros((params.action_dim,))

print(" Running baseline (zero action)...")
state_zero, _ = run_steps(env, params, zero_action, num_steps=1)

print(" Running actuated case...")
state_act, _ = run_steps(env, params, action_actuated, num_steps=1)

traj_zero = state_zero.trajectory
traj_act = state_act.trajectory

n_snap = min(4, len(traj_zero))
idxs = np.linspace(0, len(traj_zero) - 1, n_snap, dtype=int)

fig, axes = plt.subplots(n_snap, 2, figsize=(10, 3 * n_snap))
if n_snap == 1:
axes = np.array([axes])

for i, idx in enumerate(idxs):
z = to_real(traj_zero[idx])
a = to_real(traj_act[idx])

im0 = axes[i, 0].imshow(z.T, origin="lower")
axes[i, 0].set_title(f"Zero action (t={idx})")
plt.colorbar(im0, ax=axes[i, 0])

im1 = axes[i, 1].imshow(a.T, origin="lower")
axes[i, 1].set_title(f"Controlled case (t={idx})")
plt.colorbar(im1, ax=axes[i, 1])

plt.tight_layout()
plt.savefig(outfile, dpi=150, bbox_inches="tight")
print(f" Saved: {outfile}")


# ── Main ────────────────────────────────────────────────────────────────────


def main():
parser = argparse.ArgumentParser(description="Kolmogorov flow JAX environment runner")
parser.add_argument(
"mode",
nargs="?",
default="minimize_tke",
choices=list(MODE_CONFIGS),
)
parser.add_argument("--num-steps", type=int, default=10)
parser.add_argument(
"--dt",
type=float,
default=None,
help="DNS timestep (default: 1e-3). Halve to improve fp32 stability.",
)
parser.add_argument(
"--dtype",
default="float64",
choices=["float32", "float64"],
help="Floating-point precision (default: float64 — required for JIT stability)",
)
parser.add_argument(
"--plot",
action="store_true",
help="Save a vorticity comparison PNG (baseline vs selected mode)",
)
args = parser.parse_args()

cfg = MODE_CONFIGS[args.mode]

print("=== Kolmogorov Flow JAX Environment ===")
print(f"Mode: {args.mode}")
print(f"Dtype: {args.dtype}")
print(f"dt: {args.dt if args.dt is not None else '1e-3 (default)'}")
print(f"Steps: {args.num_steps}")
print()
print(cfg["description"])
print()

env_config = {}
if args.dt is not None:
env_config["dt"] = args.dt
env = KolmogorovFlow(env_config=env_config, flow_config={})
params = env.default_params.replace(reward_alpha=cfg["reward_alpha"])

action = cfg["action"]
if action is None:
action = jnp.zeros((params.action_dim,))

if args.plot:
print("Generating vorticity comparison plot...")
outfile = f"kolmogorov_{args.mode}.png"
save_plot(env, params, action, outfile=outfile)
print()

print(f"{'Step':>5} {'mean_TKE':>12} {'reward':>12}")
print("-" * 35)
_, rows = run_steps(env, params, action, args.num_steps)
for step, tke, reward in rows:
print(f"{step:>5} {tke:>12.4f} {reward:>12.4f}")


if __name__ == "__main__":
main()
Loading
Loading