Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,5 @@ dev = [
"nbval>=0.11.0",
"ipdb>=0.13.13",
"pre-commit>=4.3.0",
"pyinstrument>=5.1.2",
]
65 changes: 65 additions & 0 deletions tests/benchmarks/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Graph Creation Benchmarks

This directory contains benchmarking scripts to profile the execution time and performance bottlenecks during graph creation.

## Requirements

The benchmarks rely on `pyinstrument` and `matplotlib` to generate call-stack flamegraphs and scaling plots.
Make sure you have installed the development dependencies:

```bash
uv sync --all-extras --dev
# or specifically
uv add --dev pyinstrument
```

## 1. Call Stack Flamegraphs (`graph_creation_flamegraph.py`)

You can run the script from the root of the project to profile graph creation for a specific archetype and grid size. Because the script uses the `tests` utility module, run it via the Python module syntax. By default, it will open an interactive HTML flamegraph in your browser!

```bash
uv run python -m tests.benchmarks.graph_creation_flamegraph
```

### Options

- `--N <int>`: Set the size of the input grid ($N \times N$). Default is `425` which produces ~180k points (a roughly 10s baseline for the `keisler` graph).
- `--archetype <name>`: The archetype graph to create. Options are `keisler`, `oskarsson_hierarchical`, and `graphcast`.
- `--console`: Print the profiling hierarchy to the console instead of opening the HTML flamegraph in the browser.
- `--save-flamegraph [FILENAME]`: Saves the interactive HTML flamegraph to disk and exits. If no filename is provided, defaults to `pyinstrument_profile.html`.

**Examples:**

Profile the hierarchical archetype with $200 \times 200$ points (opens in browser):
```bash
uv run python -m tests.benchmarks.graph_creation_flamegraph --N 200 --archetype oskarsson_hierarchical
```

Save the flamegraph to a custom file without opening a server:
```bash
uv run python -m tests.benchmarks.graph_creation_flamegraph --N 425 --save-flamegraph my_profile.html
```

## 2. Runtime Scaling Plot (`graph_creation_scaling.py`)

This script runs the graph creation process across a range of grid sizes and plots the execution time against the number of input nodes. This helps visualize how the algorithm's runtime scales as the number of input grid points increases.

```bash
uv run python -m tests.benchmarks.graph_creation_scaling
```

### Options

- `--min-N <int>`: The minimum grid size N ($N \times N$ nodes). Default: 50
- `--max-N <int>`: The maximum grid size N ($N \times N$ nodes). Default: 400
- `--num-steps <int>`: Total number of grid sizes to test, evenly spaced from `--min-N` to `--max-N` (both endpoints included). Default: 8
- `--archetype <name>`: The archetype graph to create. Options are `keisler`, `oskarsson_hierarchical`, and `graphcast`.
- `--output <path>`: The file path to save the generated plot. Default: `scaling_plot.png`
- `--show`: Opens a matplotlib interactive window to display the plot after benchmarking.

**Examples:**

Test scaling from $100 \times 100$ to $500 \times 500$ and open the plot interactively:
```bash
uv run python -m tests.benchmarks.graph_creation_scaling --min-N 100 --max-N 500 --num-steps 10 --show
```
Empty file added tests/benchmarks/__init__.py
Empty file.
110 changes: 110 additions & 0 deletions tests/benchmarks/graph_creation_flamegraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import argparse
import http.server
import os
import socketserver
import tempfile
import threading
import time
import webbrowser

from pyinstrument import Profiler

import tests.utils as test_utils
import weather_model_graphs as wmg


def _build_parser():
    """Return the command-line parser for the flamegraph profiling script."""
    parser = argparse.ArgumentParser(
        description="Profile graph creation with pyinstrument."
    )
    parser.add_argument(
        "--N",
        type=int,
        default=425,
        help="Size of grid (NxN points). Default is 425 (~180k points).",
    )
    parser.add_argument(
        "--archetype",
        type=str,
        default="keisler",
        choices=["keisler", "oskarsson_hierarchical", "graphcast"],
        help="Graph archetype to create.",
    )
    parser.add_argument(
        "--console",
        action="store_true",
        help="Print the profile to the console instead of opening a flamegraph in the browser.",
    )
    parser.add_argument(
        "--save-flamegraph",
        type=str,
        nargs="?",
        const="pyinstrument_profile.html",
        help="Save the HTML flamegraph to a file (default: pyinstrument_profile.html).",
    )
    return parser


def _write_html(path, html):
    """Write ``html`` to ``path`` as UTF-8.

    The encoding is pinned so the report is written identically regardless of
    the platform's default locale encoding.
    """
    with open(path, "w", encoding="utf-8") as f:
        f.write(html)


def _serve_html(html):
    """Serve ``html`` as ``index.html`` on a free localhost port until Ctrl+C.

    The page is written to a temporary directory that is removed once the
    server shuts down, and the default web browser is pointed at it.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        _write_html(os.path.join(temp_dir, "index.html"), html)

        class Handler(http.server.SimpleHTTPRequestHandler):
            def __init__(self, *args, **kwargs):
                # Serve from the temporary directory rather than the CWD.
                super().__init__(*args, directory=temp_dir, **kwargs)

            def log_message(self, format, *args):
                pass  # suppress noisy server logs

        # Binding to port 0 lets the OS pick a free port.
        with socketserver.TCPServer(("127.0.0.1", 0), Handler) as httpd:
            port = httpd.server_address[1]
            url = f"http://127.0.0.1:{port}"
            print(f"\nServing flamegraph at {url}")
            print("Press Ctrl+C to shut down the server and exit.")

            # Open the browser from a separate thread, slightly delayed, so
            # the server is already accepting requests when the page loads.
            def open_browser():
                time.sleep(0.5)
                webbrowser.open(url)

            threading.Thread(target=open_browser, daemon=True).start()

            try:
                httpd.serve_forever()
            except KeyboardInterrupt:
                print("\nShutting down server.")


def main(argv=None):
    """Profile the creation of one archetype graph and report the result.

    Args:
        argv: Optional list of command-line arguments. When ``None`` (the
            default) the arguments are taken from ``sys.argv``.

    The profile is saved to disk (``--save-flamegraph``), printed
    (``--console``), or, when neither is given, served on localhost and
    opened in the browser. Invalid arguments terminate via ``SystemExit``.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)
    if args.N < 1:
        parser.error("--N must be a positive integer.")

    print(f"Generating input coordinates for N={args.N} ({args.N**2} points)...")
    xy = test_utils.create_fake_xy(N=args.N)

    # Get the graph creation function dynamically based on the argument
    fn_name = f"create_{args.archetype}_graph"
    create_fn = getattr(wmg.create.archetype, fn_name)

    print(f"Starting pyinstrument profiling for {fn_name}...")
    profiler = Profiler(interval=0.001)  # sample every 1 ms

    profiler.start()
    try:
        # perf_counter is monotonic, so the measured duration is not affected
        # by system clock adjustments.
        t0 = time.perf_counter()
        graph = create_fn(coords=xy)
        elapsed = time.perf_counter() - t0
    finally:
        # Always stop the profiler, even if graph creation fails.
        profiler.stop()

    print(f"Graph creation finished in {elapsed:.2f} seconds.")
    print(f"Graph has {len(graph.nodes)} nodes and {len(graph.edges)} edges.")

    if args.save_flamegraph:
        _write_html(args.save_flamegraph, profiler.output_html())
        print(f"Detailed report saved to '{args.save_flamegraph}'.")

    if args.console:
        print("\n--- Profile Output ---")
        print(profiler.output_text(unicode=True, color=True))
    elif not args.save_flamegraph:
        _serve_html(profiler.output_html())


if __name__ == "__main__":
    main()
91 changes: 91 additions & 0 deletions tests/benchmarks/graph_creation_scaling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import argparse
import time

import matplotlib.pyplot as plt
import numpy as np

import tests.utils as test_utils
import weather_model_graphs as wmg


def _build_parser():
    """Return the command-line parser for the scaling benchmark."""
    parser = argparse.ArgumentParser(description="Benchmark graph creation scaling.")
    parser.add_argument(
        "--min-N", type=int, default=50, help="Minimum grid size N (NxN nodes)."
    )
    parser.add_argument(
        "--max-N", type=int, default=400, help="Maximum grid size N (NxN nodes)."
    )
    parser.add_argument(
        "--num-steps",
        type=int,
        default=8,
        help="Number of grid sizes to test from min-N to max-N (inclusive).",
    )
    parser.add_argument(
        "--archetype",
        type=str,
        default="keisler",
        choices=["keisler", "oskarsson_hierarchical", "graphcast"],
        help="Graph archetype to create.",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="scaling_plot.png",
        help="Path to save the output plot.",
    )
    parser.add_argument(
        "--show", action="store_true", help="Show the plot interactively."
    )
    return parser


def _grid_sizes(min_n, max_n, num_steps):
    """Return the sorted, de-duplicated grid sizes to benchmark as plain ints.

    Casting the evenly spaced values to integers can yield repeated sizes when
    the range is narrow; duplicates are dropped so no size is timed twice.
    """
    sizes = np.linspace(min_n, max_n, num_steps, dtype=int)
    return [int(size) for size in np.unique(sizes)]


def _plot_scaling(num_nodes_list, times, archetype, output, show):
    """Plot measured runtimes against node counts and save the figure.

    A dashed reference line shows linear scaling in the number of nodes,
    anchored at the first measurement.
    """
    fig = plt.figure(figsize=(10, 6))
    plt.plot(
        num_nodes_list,
        times,
        marker="o",
        linestyle="-",
        linewidth=2,
        label="Measured",
    )

    # Reference line for linear scaling, fitted to the first point
    base_nodes, base_time = num_nodes_list[0], times[0]
    ref_linear = [base_time * (nodes / base_nodes) for nodes in num_nodes_list]
    plt.plot(
        num_nodes_list, ref_linear, linestyle="--", color="gray", label="O(N) Reference"
    )

    plt.title(f"Graph Creation Scaling: {archetype}")
    plt.xlabel("Number of Input Grid Nodes (N²)")
    plt.ylabel("Execution Time (seconds)")
    plt.grid(True, which="both", ls="--", alpha=0.7)
    plt.legend()
    plt.tight_layout()

    plt.savefig(output)
    print(f"\nPlot saved to {output}")

    if show:
        plt.show()
    plt.close(fig)


def main(argv=None):
    """Time graph creation over a range of grid sizes and plot the scaling.

    Args:
        argv: Optional list of command-line arguments. When ``None`` (the
            default) the arguments are taken from ``sys.argv``.

    Invalid arguments (non-positive grid sizes or fewer than one step)
    terminate via ``SystemExit`` before any benchmarking starts.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    # Reject inputs that would otherwise fail late: a zero-sized grid divides
    # by zero in the reference line, and zero steps leaves nothing to plot.
    if min(args.min_N, args.max_N) < 1:
        parser.error("--min-N and --max-N must be positive integers.")
    if args.num_steps < 1:
        parser.error("--num-steps must be at least 1.")

    grid_sizes = _grid_sizes(args.min_N, args.max_N, args.num_steps)

    fn_name = f"create_{args.archetype}_graph"
    create_fn = getattr(wmg.create.archetype, fn_name)

    num_nodes_list = []
    times = []

    print(f"Benchmarking scaling for {fn_name}...")
    for n in grid_sizes:
        num_nodes = n * n
        print(f"Testing N={n:4d} ({num_nodes:7d} nodes)...", end="", flush=True)
        xy = test_utils.create_fake_xy(N=n)

        # perf_counter is monotonic and intended for measuring durations.
        t0 = time.perf_counter()
        create_fn(coords=xy)
        duration = time.perf_counter() - t0

        print(f" {duration:.3f} seconds.")

        num_nodes_list.append(num_nodes)
        times.append(duration)

    _plot_scaling(num_nodes_list, times, args.archetype, args.output, args.show)


if __name__ == "__main__":
    main()
Loading