diff --git a/pyproject.toml b/pyproject.toml index 01eb298..28f773e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,4 +47,5 @@ dev = [ "nbval>=0.11.0", "ipdb>=0.13.13", "pre-commit>=4.3.0", + "pyinstrument>=5.1.2", ] diff --git a/tests/benchmarks/README.md b/tests/benchmarks/README.md new file mode 100644 index 0000000..720edda --- /dev/null +++ b/tests/benchmarks/README.md @@ -0,0 +1,65 @@ +# Graph Creation Benchmarks + +This directory contains benchmarking scripts to profile the execution time and performance bottlenecks during graph creation. + +## Requirements + +The benchmarks rely on `pyinstrument` and `matplotlib` to generate call-stack flamegraphs and scaling plots. +Make sure you have installed the development dependencies: + +```bash +uv sync --all-extras --dev +# or specifically +uv add --dev pyinstrument +``` + +## 1. Call Stack Flamegraphs (`graph_creation_flamegraph.py`) + +You can run the script from the root of the project to profile graph creation for a specific archetype and grid size. Because the script uses the `tests` utility module, run it via the Python module syntax. By default, it will open an interactive HTML flamegraph in your browser! + +```bash +uv run python -m tests.benchmarks.graph_creation_flamegraph +``` + +### Options + +- `--N <int>`: Set the size of the input grid ($N \times N$). Default is `425`, which produces ~180k points (a roughly 10s baseline for the `keisler` graph). +- `--archetype <name>`: The archetype graph to create. Options are `keisler`, `oskarsson_hierarchical`, and `graphcast`. +- `--console`: Print the profiling hierarchy to the console instead of opening the HTML flamegraph in the browser. +- `--save-flamegraph [FILENAME]`: Saves the interactive HTML flamegraph to disk and exits. If no filename is provided, defaults to `pyinstrument_profile.html`.
+ +**Examples:** + +Profile the hierarchical archetype with $200 \times 200$ points (opens in browser): +```bash +uv run python -m tests.benchmarks.graph_creation_flamegraph --N 200 --archetype oskarsson_hierarchical +``` + +Save the flamegraph to a custom file without opening a server: +```bash +uv run python -m tests.benchmarks.graph_creation_flamegraph --N 425 --save-flamegraph my_profile.html +``` + +## 2. Runtime Scaling Plot (`graph_creation_scaling.py`) + +This script runs the graph creation process across a range of different grid sizes and plots the execution time versus the number of input nodes. This helps visualize how the algorithm's runtime scales as the number of input coordinates increases. + +```bash +uv run python -m tests.benchmarks.graph_creation_scaling +``` + +### Options + +- `--min-N <int>`: The minimum grid size N ($N \times N$ nodes). Default: 50 +- `--max-N <int>`: The maximum grid size N ($N \times N$ nodes). Default: 400 +- `--num-steps <int>`: Total number of grid sizes to test, evenly spaced from min to max (both endpoints included). Default: 8 +- `--archetype <name>`: The archetype graph to create. Options are `keisler`, `oskarsson_hierarchical`, and `graphcast`. +- `--output <path>`: The file path to save the generated plot. Default: `scaling_plot.png` +- `--show`: Opens a matplotlib interactive window to display the plot after benchmarking.
def _parse_args(argv=None):
    """Parse and validate the command-line options of the flamegraph benchmark.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read ``sys.argv``.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: On invalid options (argparse reports the error and exits
            with status 2), including a non-positive ``--N``.
    """
    parser = argparse.ArgumentParser(
        description="Profile graph creation with pyinstrument."
    )
    parser.add_argument(
        "--N",
        type=int,
        default=425,
        help="Size of grid (NxN points). Default is 425 (~180k points).",
    )
    parser.add_argument(
        "--archetype",
        type=str,
        default="keisler",
        choices=["keisler", "oskarsson_hierarchical", "graphcast"],
        help="Graph archetype to create.",
    )
    parser.add_argument(
        "--console",
        action="store_true",
        help="Print the profile to the console instead of opening a flamegraph in the browser.",
    )
    parser.add_argument(
        "--save-flamegraph",
        type=str,
        nargs="?",
        const="pyinstrument_profile.html",
        help="Save the HTML flamegraph to a file (default: pyinstrument_profile.html).",
    )

    args = parser.parse_args(argv)
    # Reject an empty/negative grid up front instead of failing somewhere
    # inside coordinate generation or graph creation.
    if args.N < 1:
        parser.error("--N must be a positive integer")
    return args


def _serve_flamegraph(html):
    """Serve *html* from a temporary local HTTP server and open it in a browser.

    Blocks until the user interrupts with Ctrl+C; the temporary directory and
    the server socket are released on exit by their context managers.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        html_path = os.path.join(temp_dir, "index.html")
        # Explicit UTF-8 so the report does not depend on the platform's
        # default text encoding.
        with open(html_path, "w", encoding="utf-8") as f:
            f.write(html)

        class Handler(http.server.SimpleHTTPRequestHandler):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, directory=temp_dir, **kwargs)

            def log_message(self, format, *args):
                pass  # suppress noisy server logs

        # Find a free port by binding to port 0
        with socketserver.TCPServer(("127.0.0.1", 0), Handler) as httpd:
            port = httpd.server_address[1]
            url = f"http://127.0.0.1:{port}"
            print(f"\nServing flamegraph at {url}")
            print("Press Ctrl+C to shut down the server and exit.")

            # Open the browser from a background timer so the server is
            # already accepting requests when the page loads.
            opener = threading.Timer(0.5, webbrowser.open, args=(url,))
            opener.daemon = True
            opener.start()

            try:
                httpd.serve_forever()
            except KeyboardInterrupt:
                print("\nShutting down server.")


def main(argv=None):
    """Profile one graph-creation call with pyinstrument and report the result.

    Depending on the options the profile is saved as HTML, printed to the
    console, or (by default) served from a local HTTP server and opened in the
    browser.

    Args:
        argv: Optional argument list; ``None`` uses the process command line.
    """
    args = _parse_args(argv)

    print(f"Generating input coordinates for N={args.N} ({args.N**2} points)...")
    xy = test_utils.create_fake_xy(N=args.N)

    # Get the graph creation function dynamically based on the argument
    fn_name = f"create_{args.archetype}_graph"
    create_fn = getattr(wmg.create.archetype, fn_name)

    print(f"Starting pyinstrument profiling for {fn_name}...")
    profiler = Profiler(interval=0.001)  # 1ms precision

    # Profile the function. perf_counter() is monotonic, unlike time.time(),
    # so the measured duration is immune to system clock adjustments.
    profiler.start()
    t0 = time.perf_counter()
    try:
        graph = create_fn(coords=xy)
        elapsed = time.perf_counter() - t0
    finally:
        # Always stop the profiler, even when graph creation raises.
        profiler.stop()

    print(f"Graph creation finished in {elapsed:.2f} seconds.")
    print(f"Graph has {len(graph.nodes)} nodes and {len(graph.edges)} edges.")

    if args.save_flamegraph:
        with open(args.save_flamegraph, "w", encoding="utf-8") as f:
            f.write(profiler.output_html())
        print(f"Detailed report saved to '{args.save_flamegraph}'.")

    if args.console:
        print("\n--- Profile Output ---")
        print(profiler.output_text(unicode=True, color=True))
    elif not args.save_flamegraph:
        _serve_flamegraph(profiler.output_html())
def _grid_sizes(min_n, max_n, num_steps):
    """Return the ascending, de-duplicated grid side lengths to benchmark.

    ``np.linspace`` with an integer dtype truncates, so a narrow range with
    many steps produces repeated values; ``np.unique`` drops the repeats so no
    size is benchmarked (and plotted) twice. Values are converted to plain
    Python ints.
    """
    sizes = np.linspace(min_n, max_n, num_steps, dtype=int)
    return [int(n) for n in np.unique(sizes)]


def main(argv=None):
    """Benchmark graph creation over a range of grid sizes and plot the timings.

    Each grid size N is timed once; the plot shows run time against the number
    of input nodes (N * N) together with a linear reference line anchored at
    the first measurement. The plot is written to ``--output`` and optionally
    shown interactively.

    Args:
        argv: Optional argument list; ``None`` uses the process command line.

    Raises:
        SystemExit: On invalid options (argparse exits with status 2).
    """
    parser = argparse.ArgumentParser(description="Benchmark graph creation scaling.")
    parser.add_argument(
        "--min-N", type=int, default=50, help="Minimum grid size N (NxN nodes)."
    )
    parser.add_argument(
        "--max-N", type=int, default=400, help="Maximum grid size N (NxN nodes)."
    )
    parser.add_argument(
        "--num-steps",
        type=int,
        default=8,
        # linspace includes both endpoints, so this is the total count of
        # sizes, not the number of points strictly between min and max.
        help="Number of grid sizes to test, evenly spaced from min-N to max-N inclusive.",
    )
    parser.add_argument(
        "--archetype",
        type=str,
        default="keisler",
        choices=["keisler", "oskarsson_hierarchical", "graphcast"],
        help="Graph archetype to create.",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="scaling_plot.png",
        help="Path to save the output plot.",
    )
    parser.add_argument(
        "--show", action="store_true", help="Show the plot interactively."
    )

    args = parser.parse_args(argv)

    # Validate before any work is done: with zero steps the reference line
    # below would index an empty list.
    if args.min_N < 1:
        parser.error("--min-N must be a positive integer")
    if args.max_N < args.min_N:
        parser.error("--max-N must be greater than or equal to --min-N")
    if args.num_steps < 1:
        parser.error("--num-steps must be at least 1")

    grid_sizes = _grid_sizes(args.min_N, args.max_N, args.num_steps)

    fn_name = f"create_{args.archetype}_graph"
    create_fn = getattr(wmg.create.archetype, fn_name)

    num_nodes_list = []
    times = []

    print(f"Benchmarking scaling for {fn_name}...")
    for n in grid_sizes:
        num_nodes = n * n
        print(f"Testing N={n:4d} ({num_nodes:7d} nodes)...", end="", flush=True)
        xy = test_utils.create_fake_xy(N=n)

        # perf_counter() is monotonic; time.time() can jump when the system
        # clock is adjusted and would corrupt the measurement.
        t0 = time.perf_counter()
        create_fn(coords=xy)
        duration = time.perf_counter() - t0

        print(f" {duration:.3f} seconds.")

        num_nodes_list.append(num_nodes)
        times.append(duration)

    # Create the plot
    plt.figure(figsize=(10, 6))
    plt.plot(
        num_nodes_list,
        times,
        marker="o",
        linestyle="-",
        linewidth=2,
        label="Measured",  # without a label the legend omits the real data
    )

    # Reference line for linear scaling in the node count, anchored at the
    # first measurement. Labelled without "N" because N denotes the grid side
    # length everywhere else in this script.
    ref_linear = [times[0] * (nodes / num_nodes_list[0]) for nodes in num_nodes_list]
    plt.plot(
        num_nodes_list,
        ref_linear,
        linestyle="--",
        color="gray",
        label="Linear scaling reference",
    )

    plt.title(f"Graph Creation Scaling: {args.archetype}")
    plt.xlabel("Number of Input Grid Nodes (N²)")
    plt.ylabel("Execution Time (seconds)")
    plt.grid(True, which="both", ls="--", alpha=0.7)
    plt.legend()
    plt.tight_layout()

    plt.savefig(args.output)
    print(f"\nPlot saved to {args.output}")

    if args.show:
        plt.show()