From 380ebd91081ee8cb4598030feb0b88ae942bbc5d Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Thu, 12 Mar 2026 15:30:45 +0000 Subject: [PATCH 1/8] Add benchmark framework with registry pattern and SHT benchmarks Introduce a torch_harmonics.benchmark subpackage with: - Timer infrastructure (CUDATimer, NullTimer, CPUEventPair) for GPU event-based and CPU wall-clock timing - BenchmarkABC base class with registry via @register_benchmark - CLI runner (python -m torch_harmonics.benchmark) that saves JSON results - RealSHT and InverseRealSHT benchmarks at 1-degree resolution Also add benchmark_results to .gitignore. Co-Authored-By: Claude Opus 4.6 --- .gitignore | 3 +- torch_harmonics/benchmark/__init__.py | 15 ++ torch_harmonics/benchmark/__main__.py | 6 + torch_harmonics/benchmark/benchmark.py | 93 +++++++++++ torch_harmonics/benchmark/run.py | 128 ++++++++++++++++ torch_harmonics/benchmark/sht.py | 83 ++++++++++ torch_harmonics/benchmark/timer.py | 203 +++++++++++++++++++++++++ 7 files changed, 530 insertions(+), 1 deletion(-) create mode 100644 torch_harmonics/benchmark/__init__.py create mode 100644 torch_harmonics/benchmark/__main__.py create mode 100644 torch_harmonics/benchmark/benchmark.py create mode 100644 torch_harmonics/benchmark/run.py create mode 100644 torch_harmonics/benchmark/sht.py create mode 100644 torch_harmonics/benchmark/timer.py diff --git a/.gitignore b/.gitignore index eef5e947..47d28e99 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ *.DS_Store __pycache__ *.so -checkpoints \ No newline at end of file +checkpoints +*benchmark_results diff --git a/torch_harmonics/benchmark/__init__.py b/torch_harmonics/benchmark/__init__.py new file mode 100644 index 00000000..1baa3686 --- /dev/null +++ b/torch_harmonics/benchmark/__init__.py @@ -0,0 +1,15 @@ +from torch_harmonics.benchmark.benchmark import ( + BenchmarkABC, + BenchmarkResult, + get_benchmarks, + register_benchmark, +) +from torch_harmonics.benchmark.timer import ( + CUDATimer, + NullTimer, + Timer, + TimerResult, +) + +# Import to trigger registration of built-in benchmarks. +import torch_harmonics.benchmark.sht # noqa: F401 diff --git a/torch_harmonics/benchmark/__main__.py b/torch_harmonics/benchmark/__main__.py new file mode 100644 index 00000000..84c9213e --- /dev/null +++ b/torch_harmonics/benchmark/__main__.py @@ -0,0 +1,6 @@ +import sys + +import torch_harmonics.benchmark # noqa: F401 — triggers benchmark registration +from torch_harmonics.benchmark.run import cli + +sys.exit(cli()) diff --git a/torch_harmonics/benchmark/benchmark.py b/torch_harmonics/benchmark/benchmark.py new file mode 100644 index 00000000..00ddf1bb --- /dev/null +++ b/torch_harmonics/benchmark/benchmark.py @@ -0,0 +1,93 @@ +import abc +import dataclasses +from collections.abc import Callable +from typing import Self + +import torch + +from torch_harmonics.benchmark.timer import ( + CPUEventPair, + CUDATimer, + NullTimer, + Timer, + TimerResult, +) + +TensorDict = dict[str, torch.Tensor] + + +@dataclasses.dataclass +class BenchmarkResult: + timer: TimerResult + cpu_time: float + + def __repr__(self) -> str: + return f"BenchmarkResult(timer={self.timer}, cpu_time={self.cpu_time})" + + def asdict(self) -> dict: + return dataclasses.asdict(self) + + def get_logs(self, max_depth: int) -> dict[str, float]: + logs = {"cpu_time": self.cpu_time} + logs.update(self.timer.get_logs(max_depth=max_depth)) + return logs + + +class BenchmarkABC(abc.ABC): + @classmethod + @abc.abstractmethod + def new(cls: type[Self]) -> Self: + """ + Initialize any state needed for the benchmark. + This will be called once before the benchmark is run. + """ + pass + + @classmethod + def run_benchmark(cls, iters=10, warmup=1) -> BenchmarkResult: + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available, cannot run benchmark.") + null_timer = NullTimer() + benchmark = cls.new() + for _ in range(warmup): + benchmark.run_instance(null_timer) + timer = CUDATimer() + cpu_timer = CPUEventPair() + cpu_timer.record_start() + for _ in range(iters): + with timer: + benchmark.run_instance(timer) + torch.cuda.synchronize() + cpu_timer.record_end() + return BenchmarkResult( + timer=timer.result, + cpu_time=cpu_timer.elapsed_time_ms(), + ) + + @abc.abstractmethod + def run_instance(self: Self, timer: Timer) -> TensorDict: + """ + Run the benchmark. This will be called multiple times, + and should return a TensorDict of results. + + This must not mutate any state on self, since the same instance may be + used across multiple iterations. + """ + pass + + +_BENCHMARKS: dict[str, type[BenchmarkABC]] = {} + + +def register_benchmark(name: str) -> Callable[[type[BenchmarkABC]], type[BenchmarkABC]]: + def _register(fn: type[BenchmarkABC]) -> type[BenchmarkABC]: + if name in _BENCHMARKS: + raise ValueError(f"Benchmark with name '{name}' is already registered.") + _BENCHMARKS[name] = fn + return fn + + return _register + + +def get_benchmarks() -> dict[str, type[BenchmarkABC]]: + return _BENCHMARKS.copy() diff --git a/torch_harmonics/benchmark/run.py b/torch_harmonics/benchmark/run.py new file mode 100644 index 00000000..a3d8f067 --- /dev/null +++ b/torch_harmonics/benchmark/run.py @@ -0,0 +1,128 @@ +import argparse +import dataclasses +import json +import logging +import pathlib +import subprocess +import sys + +import torch + +from torch_harmonics.benchmark.benchmark import get_benchmarks + +_GIT_COMMIT: str | None = None + + +def get_git_commit() -> str: + global _GIT_COMMIT + if _GIT_COMMIT is None: + try: + commit = ( + subprocess.check_output( + ["git", "rev-parse", "--short", "HEAD"], + stderr=subprocess.DEVNULL, + ) + .decode() + .strip() + ) + dirty = ( + subprocess.check_output( + ["git", "status", "--porcelain"], + stderr=subprocess.DEVNULL, + ) + .decode() + .strip() + ) + if dirty: + commit = f"{commit}-dirty" + except (subprocess.CalledProcessError, FileNotFoundError): + commit = "unknown" + _GIT_COMMIT = commit + return _GIT_COMMIT + + +def get_device_name() -> str: + if torch.cuda.is_available(): + return torch.cuda.get_device_properties(0).name + else: + return "CPU" + + +def main( + benchmark_name: str | None, + iters: int, + output_dir: pathlib.Path, +) -> int: + output_dir.mkdir(parents=True, exist_ok=True) + device_name = get_device_name() + safe_device_name = device_name.replace(" ", "_").replace("/", "_").lower() + + logging.info(f"Running benchmarks on device: {device_name}") + benchmarks = get_benchmarks() + if benchmark_name is not None: + if benchmark_name not in benchmarks: + logging.error( + f"Specified benchmark {benchmark_name} not found. " + f"Available benchmarks: {', '.join(benchmarks.keys())}" + ) + return 1 + benchmarks_to_run = {benchmark_name: benchmarks[benchmark_name]} + else: + benchmarks_to_run = benchmarks + + def get_filename(name, extension) -> pathlib.Path: + safe_name = name.replace("/", "_").replace(".", "_").lower() + return ( + output_dir + / f"{safe_name}_{safe_device_name}_{get_git_commit()}.{extension}" + ) + + for name, cls in benchmarks_to_run.items(): + logging.info(f"Running benchmark: {name}") + result = cls.run_benchmark(iters=iters) + result_data = json.dumps(dataclasses.asdict(result), indent=2) + logging.info(f"Result: {result_data}") + json_path = get_filename(name, "json") + with open(json_path, "w") as f: + logging.info(f"Saving result json to {f.name}") + f.write(result_data) + + return 0 + + +def cli() -> int: + logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" + ) + parser = argparse.ArgumentParser(description="Run registered benchmarks.") + parser.add_argument( + "--name", + type=str, + default=None, + help=( + "Name of the benchmark to run. If not provided, " + "all benchmarks will be run." + ), + ) + parser.add_argument( + "--iters", + type=int, + default=10, + help="Number of iterations to run each benchmark for.", + ) + parser.add_argument( + "--output-dir", + type=str, + default="benchmark_results", + help="Directory to save benchmark results in.", + ) + args = parser.parse_args() + return main( + benchmark_name=args.name, + iters=args.iters, + output_dir=pathlib.Path(args.output_dir), + ) + + +if __name__ == "__main__": + sys.exit(cli()) diff --git a/torch_harmonics/benchmark/sht.py b/torch_harmonics/benchmark/sht.py new file mode 100644 index 00000000..98b14625 --- /dev/null +++ b/torch_harmonics/benchmark/sht.py @@ -0,0 +1,83 @@ +import abc +from typing import Self, final + +import torch + +from torch_harmonics.benchmark.benchmark import ( + BenchmarkABC, + TensorDict, + register_benchmark, +) +from torch_harmonics.benchmark.timer import Timer +from torch_harmonics.sht import InverseRealSHT, RealSHT + + +def _get_device(): + return torch.device("cuda", torch.cuda.current_device()) + + +class RealSHTBenchmark(BenchmarkABC): + + @final + def __init__(self, forward_sht: RealSHT, x: torch.Tensor): + self.forward_sht = forward_sht + self.x = x + + @classmethod + @abc.abstractmethod + def new(cls) -> "RealSHTBenchmark": ... + + @classmethod + @final + def new_with_shape(cls: type[Self], B: int, H: int, L: int) -> Self: + device = _get_device() + x = torch.randn(B, H, L, device=device) + forward_sht = RealSHT(nlat=H, nlon=L).to(device) + return cls(forward_sht=forward_sht, x=x) + + @final + def run_instance(self, timer: Timer) -> TensorDict: + result = self.forward_sht(self.x) + return {"output": result.detach()} + + +@register_benchmark("real_sht_1deg") +class RealSHTBenchmark1Degree(RealSHTBenchmark): + + @classmethod + def new(cls) -> "RealSHTBenchmark1Degree": + return cls.new_with_shape(B=4096, H=180, L=360) + + +class InverseRealSHTBenchmark(BenchmarkABC): + + @final + def __init__(self, inverse_sht: InverseRealSHT, x_hat: torch.Tensor): + self.inverse_sht = inverse_sht + self.x_hat = x_hat + + @classmethod + @abc.abstractmethod + def new(cls) -> "InverseRealSHTBenchmark": ... + + @classmethod + @final + def new_with_shape(cls: type[Self], B: int, H: int, L: int) -> Self: + device = _get_device() + x = torch.randn(B, H, L, device=device) + forward_sht = RealSHT(nlat=H, nlon=L).to(device) + x_hat = forward_sht(x) + inverse_sht = InverseRealSHT(nlat=H, nlon=L).to(device) + return cls(inverse_sht=inverse_sht, x_hat=x_hat) + + @final + def run_instance(self, timer: Timer) -> TensorDict: + result = self.inverse_sht(self.x_hat) + return {"output": result.detach()} + +@register_benchmark("inverse_real_sht_1deg") +class InverseRealSHTBenchmark1Degree(InverseRealSHTBenchmark): + + @classmethod + def new(cls) -> "InverseRealSHTBenchmark1Degree": + return cls.new_with_shape(B=4096, H=180, L=360) diff --git a/torch_harmonics/benchmark/timer.py b/torch_harmonics/benchmark/timer.py new file mode 100644 index 00000000..44bd5190 --- /dev/null +++ b/torch_harmonics/benchmark/timer.py @@ -0,0 +1,203 @@ +import collections +import dataclasses +import time +from typing import Literal, Protocol, Self + +import torch + + +@dataclasses.dataclass +class TimerResult: + count: int + avg_time: float + children: dict[str, "TimerResult"] + + def get_logs(self, max_depth: int) -> dict[str, float]: + logs = { + "avg_time": self.avg_time, + } + if max_depth > 0: + for child_name, child in self.children.items(): + for log_name, value in child.get_logs(max_depth=max_depth - 1).items(): + logs[f"{child_name}/{log_name}"] = value + return logs + + def assert_close(self, other: "TimerResult", rtol=0.02, children_rtol=0.02) -> None: + if self.count != other.count: + raise AssertionError(f"count differ: {self.count} vs {other.count}") + if not torch.isclose( + torch.tensor(self.avg_time), torch.tensor(other.avg_time), rtol=rtol + ): + raise AssertionError( + f"avg_time differ: {self.avg_time} vs " + f"{other.avg_time} given rtol={rtol}" + ) + if self.children.keys() != other.children.keys(): + raise AssertionError( + f"children keys differ: {self.children.keys()} vs " + f"{other.children.keys()}" + ) + for key in self.children.keys(): + try: + self.children[key].assert_close( + other.children[key], rtol=children_rtol, children_rtol=children_rtol + ) + except AssertionError as e: + raise AssertionError(f"child '{key}' differ: {e}") from e + + +class Timer(Protocol): + def child(self, name: str) -> Self: ... + def __enter__(self) -> Self: ... + def __exit__(self, exc_type, exc_val, exc_tb) -> Literal[False]: ... + + +class NullTimer: + def child(self, name: str) -> "NullTimer": + return self + + def __enter__(self) -> "NullTimer": + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> Literal[False]: + return False + + +_: Timer = NullTimer() +del _ + + +class EventPair: + def __init__(self): + self.start = torch.cuda.Event(enable_timing=True) + self.end = torch.cuda.Event(enable_timing=True) + self._stream = None + self._start_recorded = False + self._end_recorded = False + + def record_start(self): + if self._start_recorded: + raise RuntimeError( + "record_start has already been called on this EventPair." + ) + self._stream = torch.cuda.current_stream() + self.start.record(self._stream) + self._start_recorded = True + + def record_end(self): + if not self._start_recorded: + raise RuntimeError("record_start must be called before record_end") + if self._end_recorded: + raise RuntimeError("record_end has already been called on this EventPair.") + if self._stream is None: + raise RuntimeError("record_start must be called before record_end") + self.end.record(self._stream) + self._end_recorded = True + + def elapsed_time_ms(self) -> float: + if not self._start_recorded or not self._end_recorded: + raise RuntimeError( + "Both record_start and record_end must be called " + "before elapsed_time_ms can be called." + ) + return self.start.elapsed_time(self.end) + + +class CPUEventPair: + def __init__(self): + self.start_time = None + self.end_time = None + + def record_start(self): + if self.start_time is not None: + raise RuntimeError( + "record_start has already been called on this CPUEventPair." + ) + self.start_time = time.time() + + def record_end(self): + if self.start_time is None: + raise RuntimeError("record_start must be called before record_end") + if self.end_time is not None: + raise RuntimeError( + "record_end has already been called on this CPUEventPair." + ) + self.end_time = time.time() + + def elapsed_time_ms(self) -> float: + if self.start_time is None or self.end_time is None: + raise RuntimeError( + "Both record_start and record_end must be called " + "before elapsed_time_ms can be called." + ) + return (self.end_time - self.start_time) * 1000 + + +class CUDATimer: + def __init__(self): + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available, cannot use CUDATimer.") + self._children: collections.defaultdict[str, CUDATimer] = ( + collections.defaultdict(CUDATimer) + ) + self._event_pairs: list[EventPair] = [] + self._entered = False + self._result: TimerResult | None = None + + @classmethod + def new_if_available(cls) -> "CUDATimer | NullTimer": + if torch.cuda.is_available(): + return cls() + else: + return NullTimer() + + def __enter__(self): + if self._entered: + raise RuntimeError("CUDATimer is already entered.") + self._entered = True + self._event_pairs.append(EventPair()) + self._event_pairs[-1].record_start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if not self._event_pairs: + raise RuntimeError("CUDATimer context was not properly entered.") + self._event_pairs[-1].record_end() + self._entered = False + return False + + def child(self, name: str) -> "CUDATimer": + if not self._entered: + raise RuntimeError( + "CUDATimer child cannot be used before entering the timer." + ) + return self._children[name] + + @property + def _avg_time(self) -> float: + if len(self._event_pairs) == 0: + raise RuntimeError( + "CUDATimer report cannot be generated before entering the timer." + ) + total_time = sum( + event_pair.elapsed_time_ms() for event_pair in self._event_pairs + ) + return total_time / len(self._event_pairs) + + def _child_reports(self) -> dict[str, TimerResult]: + return {name: child.result for name, child in self._children.items()} + + @property + def result(self) -> TimerResult: + if self._result is None: + torch.cuda.synchronize() + self._result = TimerResult( + count=len(self._event_pairs), + avg_time=self._avg_time, + children=self._child_reports(), + ) + return self._result + + +__: type[Timer] = CUDATimer +del __ From ad511f1ac4be3187cb8a5090e72643ca07c786d2 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Thu, 12 Mar 2026 15:40:32 +0000 Subject: [PATCH 2/8] Add DiscreteContinuousConvS2 benchmark using torch sparse path Register a disco_conv_s2_torch_1deg benchmark at 1-degree resolution (B=4, 4 channels, 180x360) using the non-optimized torch contraction path, which does not require the custom CUDA extension. Co-Authored-By: Claude Opus 4.6 --- torch_harmonics/benchmark/__init__.py | 1 + torch_harmonics/benchmark/disco.py | 68 +++++++++++++++++++++++++++ 2 files changed, 69 insertions(+) create mode 100644 torch_harmonics/benchmark/disco.py diff --git a/torch_harmonics/benchmark/__init__.py b/torch_harmonics/benchmark/__init__.py index 1baa3686..0c6823a6 100644 --- a/torch_harmonics/benchmark/__init__.py +++ b/torch_harmonics/benchmark/__init__.py @@ -13,3 +13,4 @@ # Import to trigger registration of built-in benchmarks. import torch_harmonics.benchmark.sht # noqa: F401 +import torch_harmonics.benchmark.disco # noqa: F401 diff --git a/torch_harmonics/benchmark/disco.py b/torch_harmonics/benchmark/disco.py new file mode 100644 index 00000000..7b51c268 --- /dev/null +++ b/torch_harmonics/benchmark/disco.py @@ -0,0 +1,68 @@ +import abc +from typing import Self, final + +import torch + +from torch_harmonics.benchmark.benchmark import ( + BenchmarkABC, + TensorDict, + register_benchmark, +) +from torch_harmonics.benchmark.timer import Timer +from torch_harmonics.disco import DiscreteContinuousConvS2 + + +def _get_device(): + return torch.device("cuda", torch.cuda.current_device()) + + +class DiscreteContinuousConvS2Benchmark(BenchmarkABC): + + @final + def __init__(self, conv: DiscreteContinuousConvS2, x: torch.Tensor): + self.conv = conv + self.x = x + + @classmethod + @abc.abstractmethod + def new(cls) -> "DiscreteContinuousConvS2Benchmark": ... + + @classmethod + @final + def new_with_shape( + cls: type[Self], + B: int, + in_channels: int, + out_channels: int, + nlat: int, + nlon: int, + kernel_shape: int = 3, + ) -> Self: + device = _get_device() + theta_cutoff = (kernel_shape + 1) * torch.pi / float(nlat - 1) + conv = DiscreteContinuousConvS2( + in_channels=in_channels, + out_channels=out_channels, + in_shape=(nlat, nlon), + out_shape=(nlat, nlon), + kernel_shape=kernel_shape, + theta_cutoff=theta_cutoff, + optimized_kernel=False, + ).to(device) + x = torch.randn(B, in_channels, nlat, nlon, device=device) + return cls(conv=conv, x=x) + + @final + def run_instance(self, timer: Timer) -> TensorDict: + result = self.conv(self.x) + return {"output": result.detach()} + + +@register_benchmark("disco_conv_s2_torch_1deg") +class DiscreteContinuousConvS2TorchBenchmark1Degree(DiscreteContinuousConvS2Benchmark): + + @classmethod + def new(cls) -> "DiscreteContinuousConvS2TorchBenchmark1Degree": + return cls.new_with_shape( + B=4, in_channels=4, out_channels=4, nlat=180, nlon=360, + ) From 7247eb735fd3a4c57ec465800b9fa58739868d55 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Thu, 12 Mar 2026 15:44:08 +0000 Subject: [PATCH 3/8] Add hardware-dependent batch size scaling for benchmarks Introduce hardware.py with a device-name-to-scale-factor lookup table so benchmark batch sizes adapt to different GPUs. Base batch sizes are tuned for Tesla T4 (factor 1.0). Unknown devices default to 1.0 with a warning to add an entry for their hardware. Co-Authored-By: Claude Opus 4.6 --- torch_harmonics/benchmark/disco.py | 9 ++--- torch_harmonics/benchmark/hardware.py | 52 +++++++++++++++++++++++++++ torch_harmonics/benchmark/sht.py | 13 +++---- 3 files changed, 60 insertions(+), 14 deletions(-) create mode 100644 torch_harmonics/benchmark/hardware.py diff --git a/torch_harmonics/benchmark/disco.py b/torch_harmonics/benchmark/disco.py index 7b51c268..323069b2 100644 --- a/torch_harmonics/benchmark/disco.py +++ b/torch_harmonics/benchmark/disco.py @@ -8,14 +8,11 @@ TensorDict, register_benchmark, ) +from torch_harmonics.benchmark.hardware import get_device, scale_batch_size from torch_harmonics.benchmark.timer import Timer from torch_harmonics.disco import DiscreteContinuousConvS2 -def _get_device(): - return torch.device("cuda", torch.cuda.current_device()) - - class DiscreteContinuousConvS2Benchmark(BenchmarkABC): @final @@ -38,7 +35,7 @@ def new_with_shape( nlon: int, kernel_shape: int = 3, ) -> Self: - device = _get_device() + device = get_device() theta_cutoff = (kernel_shape + 1) * torch.pi / float(nlat - 1) conv = DiscreteContinuousConvS2( in_channels=in_channels, @@ -64,5 +61,5 @@ class DiscreteContinuousConvS2TorchBenchmark1Degree(DiscreteContinuousConvS2Benc @classmethod def new(cls) -> "DiscreteContinuousConvS2TorchBenchmark1Degree": return cls.new_with_shape( - B=4, in_channels=4, out_channels=4, nlat=180, nlon=360, + B=scale_batch_size(4), in_channels=4, out_channels=4, nlat=180, nlon=360, ) diff --git a/torch_harmonics/benchmark/hardware.py b/torch_harmonics/benchmark/hardware.py new file mode 100644 index 00000000..9c3378c2 --- /dev/null +++ b/torch_harmonics/benchmark/hardware.py @@ -0,0 +1,52 @@ +import logging + +import torch + +logger = logging.getLogger(__name__) + +# Batch size scale factors relative to Tesla T4 (the default baseline). +# To add a new GPU, add an entry mapping its device name (as returned by +# torch.cuda.get_device_properties(...).name) to a float scale factor. +# Values > 1.0 mean the GPU is faster than a T4 and can use larger batches; +# values < 1.0 mean it is slower. +_BATCH_SIZE_FACTORS: dict[str, float] = { + "Tesla T4": 1.0, +} + +_DEFAULT_BATCH_SIZE_FACTOR = 1.0 + + +def get_device() -> torch.device: + return torch.device("cuda", torch.cuda.current_device()) + + +def get_batch_size_factor() -> float: + """Return a hardware-dependent scale factor for benchmark batch sizes. + + Benchmarks define a base batch size tuned for a Tesla T4. This function + returns a multiplier so that benchmarks take a similar wall-clock time + on other hardware. If the batch size is too small, the GPU will not be fully + occupied, and the benchmarks cannot be used to tune performance. + + Unknown devices fall back to the T4 default (1.0). + """ + if not torch.cuda.is_available(): + return _DEFAULT_BATCH_SIZE_FACTOR + name = torch.cuda.get_device_properties(torch.cuda.current_device()).name + factor = _BATCH_SIZE_FACTORS.get(name) + if factor is None: + logger.warning( + f"Unknown GPU '{name}', using default batch size factor " + f"{_DEFAULT_BATCH_SIZE_FACTOR}. Add an entry to " + f"_BATCH_SIZE_FACTORS in hardware.py to tune for this device." + ) + return _DEFAULT_BATCH_SIZE_FACTOR + return factor + + +def scale_batch_size(base: int) -> int: + """Scale a base batch size (tuned for Tesla T4) by the hardware factor. + + Always returns at least 1. + """ + return max(1, round(base * get_batch_size_factor())) diff --git a/torch_harmonics/benchmark/sht.py b/torch_harmonics/benchmark/sht.py index 98b14625..49a93dcc 100644 --- a/torch_harmonics/benchmark/sht.py +++ b/torch_harmonics/benchmark/sht.py @@ -8,14 +8,11 @@ TensorDict, register_benchmark, ) +from torch_harmonics.benchmark.hardware import get_device, scale_batch_size from torch_harmonics.benchmark.timer import Timer from torch_harmonics.sht import InverseRealSHT, RealSHT -def _get_device(): - return torch.device("cuda", torch.cuda.current_device()) - - class RealSHTBenchmark(BenchmarkABC): @final @@ -30,7 +27,7 @@ def new(cls) -> "RealSHTBenchmark": ... @classmethod @final def new_with_shape(cls: type[Self], B: int, H: int, L: int) -> Self: - device = _get_device() + device = get_device() x = torch.randn(B, H, L, device=device) forward_sht = RealSHT(nlat=H, nlon=L).to(device) return cls(forward_sht=forward_sht, x=x) @@ -46,7 +43,7 @@ class RealSHTBenchmark1Degree(RealSHTBenchmark): @classmethod def new(cls) -> "RealSHTBenchmark1Degree": - return cls.new_with_shape(B=4096, H=180, L=360) + return cls.new_with_shape(B=scale_batch_size(4096), H=180, L=360) class InverseRealSHTBenchmark(BenchmarkABC): @@ -63,7 +60,7 @@ def new(cls) -> "InverseRealSHTBenchmark": ... @classmethod @final def new_with_shape(cls: type[Self], B: int, H: int, L: int) -> Self: - device = _get_device() + device = get_device() x = torch.randn(B, H, L, device=device) forward_sht = RealSHT(nlat=H, nlon=L).to(device) x_hat = forward_sht(x) @@ -80,4 +77,4 @@ class InverseRealSHTBenchmark1Degree(InverseRealSHTBenchmark): @classmethod def new(cls) -> "InverseRealSHTBenchmark1Degree": - return cls.new_with_shape(B=4096, H=180, L=360) + return cls.new_with_shape(B=scale_batch_size(4096), H=180, L=360) From 71eaefd9798f750bcea6c9899b0ad133a90d194e Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Wed, 18 Mar 2026 08:40:46 -0700 Subject: [PATCH 4/8] adding batch size and device ovveride to CLA --- torch_harmonics/benchmark/__init__.py | 1 + torch_harmonics/benchmark/benchmark.py | 70 ++++++++++++++++++++++---- torch_harmonics/benchmark/disco.py | 20 +++++--- torch_harmonics/benchmark/hardware.py | 29 +++++++++-- torch_harmonics/benchmark/run.py | 46 +++++++++++++---- torch_harmonics/benchmark/sht.py | 49 ++++++++++++++---- torch_harmonics/benchmark/timer.py | 64 ++++++++++++++++++++++- 7 files changed, 238 insertions(+), 41 deletions(-) diff --git a/torch_harmonics/benchmark/__init__.py b/torch_harmonics/benchmark/__init__.py index 0c6823a6..b201b2be 100644 --- a/torch_harmonics/benchmark/__init__.py +++ b/torch_harmonics/benchmark/__init__.py @@ -5,6 +5,7 @@ register_benchmark, ) from torch_harmonics.benchmark.timer import ( + CPUTimer, CUDATimer, NullTimer, Timer, diff --git a/torch_harmonics/benchmark/benchmark.py b/torch_harmonics/benchmark/benchmark.py index 00ddf1bb..aa50685e 100644 --- a/torch_harmonics/benchmark/benchmark.py +++ b/torch_harmonics/benchmark/benchmark.py @@ -5,8 +5,10 @@ import torch +from torch_harmonics.benchmark.hardware import get_device from torch_harmonics.benchmark.timer import ( CPUEventPair, + CPUTimer, CUDATimer, NullTimer, Timer, @@ -18,17 +20,19 @@ @dataclasses.dataclass class BenchmarkResult: + phase: str + device: str timer: TimerResult cpu_time: float def __repr__(self) -> str: - return f"BenchmarkResult(timer={self.timer}, cpu_time={self.cpu_time})" + return f"BenchmarkResult(phase={self.phase}, device={self.device}, timer={self.timer}, cpu_time={self.cpu_time})" def asdict(self) -> dict: return dataclasses.asdict(self) def get_logs(self, max_depth: int) -> dict[str, float]: - logs = {"cpu_time": self.cpu_time} + logs = {"phase": self.phase, "device": self.device, "cpu_time": self.cpu_time} logs.update(self.timer.get_logs(max_depth=max_depth)) return logs @@ -44,30 +48,74 @@ def new(cls: type[Self]) -> Self: pass @classmethod - def run_benchmark(cls, iters=10, warmup=1) -> BenchmarkResult: - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is not available, cannot run benchmark.") + def _make_timer(cls) -> CUDATimer | CPUTimer: + if torch.cuda.is_available(): + return CUDATimer() + return CPUTimer() + + @classmethod + def _sync(cls) -> None: + if torch.cuda.is_available(): + torch.cuda.synchronize() + + @classmethod + def run_forward_benchmark(cls, iters=10, warmup=1) -> BenchmarkResult: null_timer = NullTimer() benchmark = cls.new() for _ in range(warmup): - benchmark.run_instance(null_timer) - timer = CUDATimer() + benchmark.run_instance_forward(null_timer) + timer = cls._make_timer() cpu_timer = CPUEventPair() cpu_timer.record_start() for _ in range(iters): with timer: - benchmark.run_instance(timer) - torch.cuda.synchronize() + benchmark.run_instance_forward(timer) + cls._sync() cpu_timer.record_end() return BenchmarkResult( + phase="forward", + device=str(get_device()), timer=timer.result, cpu_time=cpu_timer.elapsed_time_ms(), ) + @classmethod + def run_backward_benchmark(cls, iters=10, warmup=1) -> BenchmarkResult: + null_timer = NullTimer() + benchmark = cls.new() + benchmark.run_instance_forward(null_timer) + for _ in range(warmup): + benchmark.run_instance_backward(null_timer) + timer = cls._make_timer() + cpu_timer = CPUEventPair() + cpu_timer.record_start() + for _ in range(iters): + with timer: + benchmark.run_instance_backward(timer) + cls._sync() + cpu_timer.record_end() + return BenchmarkResult( + phase="backward", + device=str(get_device()), + timer=timer.result, + cpu_time=cpu_timer.elapsed_time_ms(), + ) + + @abc.abstractmethod + def run_instance_forward(self: Self, timer: Timer) -> TensorDict: + """ + Run the benchmark in backward pass. This will be called multiple times, + and should return a TensorDict of results. + + This must not mutate any state on self, since the same instance may be + used across multiple iterations. + """ + pass + @abc.abstractmethod - def run_instance(self: Self, timer: Timer) -> TensorDict: + def run_instance_backward(self: Self, timer: Timer) -> TensorDict: """ - Run the benchmark. This will be called multiple times, + Run the benchmark in forward pass. This will be called multiple times, and should return a TensorDict of results. This must not mutate any state on self, since the same instance may be diff --git a/torch_harmonics/benchmark/disco.py b/torch_harmonics/benchmark/disco.py index 323069b2..05814011 100644 --- a/torch_harmonics/benchmark/disco.py +++ b/torch_harmonics/benchmark/disco.py @@ -35,24 +35,30 @@ def new_with_shape( nlon: int, kernel_shape: int = 3, ) -> Self: - device = get_device() - theta_cutoff = (kernel_shape + 1) * torch.pi / float(nlat - 1) + cls.device = get_device() conv = DiscreteContinuousConvS2( in_channels=in_channels, out_channels=out_channels, in_shape=(nlat, nlon), out_shape=(nlat, nlon), kernel_shape=kernel_shape, - theta_cutoff=theta_cutoff, + theta_cutoff=None, optimized_kernel=False, - ).to(device) - x = torch.randn(B, in_channels, nlat, nlon, device=device) + ).to(cls.device) + x = torch.randn(B, in_channels, nlat, nlon, dtype=torch.float32, device=cls.device) return cls(conv=conv, x=x) @final - def run_instance(self, timer: Timer) -> TensorDict: + def run_instance_forward(self, timer: Timer) -> TensorDict: result = self.conv(self.x) - return {"output": result.detach()} + self.output = result + return {"outputs": result.detach()} + + @final + def run_instance_backward(self, timer: Timer) -> TensorDict: + g = torch.randn_like(self.output) + self.output.backward(g, retain_graph=True) + return {"gradient": self.x.grad} @register_benchmark("disco_conv_s2_torch_1deg") diff --git a/torch_harmonics/benchmark/hardware.py b/torch_harmonics/benchmark/hardware.py index 9c3378c2..6abf0d92 100644 --- a/torch_harmonics/benchmark/hardware.py +++ b/torch_harmonics/benchmark/hardware.py @@ -15,9 +15,28 @@ _DEFAULT_BATCH_SIZE_FACTOR = 1.0 +_device: torch.device | None = None +_batch_size_override: int | None = None + + +def set_batch_size(batch_size: int) -> None: + """Override the batch size used by all benchmarks, bypassing hardware scaling.""" + global _batch_size_override + _batch_size_override = batch_size + + +def set_device(device: str | torch.device) -> None: + """Override the device used by all benchmarks.""" + global _device + _device = torch.device(device) + def get_device() -> torch.device: - return torch.device("cuda", torch.cuda.current_device()) + if _device is not None: + return _device + if torch.cuda.is_available(): + return torch.device("cuda", torch.cuda.current_device()) + return torch.device("cpu") def get_batch_size_factor() -> float: @@ -45,8 +64,12 @@ def get_batch_size_factor() -> float: def scale_batch_size(base: int) -> int: - """Scale a base batch size (tuned for Tesla T4) by the hardware factor. + """Return the batch size to use for a benchmark. - Always returns at least 1. + If a global override has been set via set_batch_size(), that value is + returned directly. Otherwise the base is scaled by the hardware factor + (tuned relative to a Tesla T4). Always returns at least 1. """ + if _batch_size_override is not None: + return max(1, _batch_size_override) return max(1, round(base * get_batch_size_factor())) diff --git a/torch_harmonics/benchmark/run.py b/torch_harmonics/benchmark/run.py index a3d8f067..f6f5a153 100644 --- a/torch_harmonics/benchmark/run.py +++ b/torch_harmonics/benchmark/run.py @@ -9,6 +9,7 @@ import torch from torch_harmonics.benchmark.benchmark import get_benchmarks +from torch_harmonics.benchmark.hardware import set_batch_size, set_device _GIT_COMMIT: str | None = None @@ -52,7 +53,12 @@ def main( benchmark_name: str | None, iters: int, output_dir: pathlib.Path, + device: str, + batch_size: int | None, ) -> int: + set_device(device) + if batch_size is not None: + set_batch_size(batch_size) output_dir.mkdir(parents=True, exist_ok=True) device_name = get_device_name() safe_device_name = device_name.replace(" ", "_").replace("/", "_").lower() @@ -70,22 +76,26 @@ def main( else: benchmarks_to_run = benchmarks - def get_filename(name, extension) -> pathlib.Path: + def get_filename(name, phase, extension) -> pathlib.Path: safe_name = name.replace("/", "_").replace(".", "_").lower() return ( output_dir - / f"{safe_name}_{safe_device_name}_{get_git_commit()}.{extension}" + / f"{safe_name}_{phase}_{safe_device_name}_{get_git_commit()}.{extension}" ) for name, cls in benchmarks_to_run.items(): - logging.info(f"Running benchmark: {name}") - result = cls.run_benchmark(iters=iters) - result_data = json.dumps(dataclasses.asdict(result), indent=2) - logging.info(f"Result: {result_data}") - json_path = get_filename(name, "json") - with open(json_path, "w") as f: - logging.info(f"Saving result json to {f.name}") - f.write(result_data) + for run_fn, phase in [ + (cls.run_forward_benchmark, "forward"), + (cls.run_backward_benchmark, "backward"), + ]: + logging.info(f"Running {phase} benchmark: {name}") + result = run_fn(iters=iters) + result_data = json.dumps(dataclasses.asdict(result), indent=2) + logging.info(f"Result:\n{result_data}") + json_path = get_filename(name, phase, "json") + with open(json_path, "w") as f: + logging.info(f"Saving result json to {f.name}") + f.write(result_data) return 0 @@ -116,11 +126,27 @@ def cli() -> int: default="benchmark_results", help="Directory to save benchmark results in.", ) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + help="Device to run benchmarks on, e.g. 'cpu', 'cuda', 'cuda:1'. " + "Defaults to 'cuda' if available, otherwise 'cpu'.", + ) + parser.add_argument( + "--batch-size", + type=int, + default=None, + help="Override the batch size for all benchmarks. If not set, each " + "benchmark uses its hardware-scaled default.", + ) args = parser.parse_args() return main( benchmark_name=args.name, iters=args.iters, output_dir=pathlib.Path(args.output_dir), + device=args.device, + batch_size=args.batch_size, ) diff --git a/torch_harmonics/benchmark/sht.py b/torch_harmonics/benchmark/sht.py index 49a93dcc..9411fc7e 100644 --- a/torch_harmonics/benchmark/sht.py +++ b/torch_harmonics/benchmark/sht.py @@ -27,17 +27,25 @@ def new(cls) -> "RealSHTBenchmark": ... @classmethod @final def new_with_shape(cls: type[Self], B: int, H: int, L: int) -> Self: - device = get_device() - x = torch.randn(B, H, L, device=device) - forward_sht = RealSHT(nlat=H, nlon=L).to(device) + cls.device = get_device() + x = torch.randn(B, H, L, device=cls.device) + x.requires_grad = True + forward_sht = RealSHT(nlat=H, nlon=L).to(cls.device) return cls(forward_sht=forward_sht, x=x) @final - def run_instance(self, timer: Timer) -> TensorDict: + def run_instance_forward(self, timer: Timer) -> TensorDict: result = self.forward_sht(self.x) + self.output = result return {"output": result.detach()} + @final + def run_instance_backward(self, timer: Timer) -> TensorDict: + g = torch.randn_like(self.output) + self.output.backward(g, retain_graph=True) + return {"gradient": self.x.grad} +# predefined benchmarks @register_benchmark("real_sht_1deg") class RealSHTBenchmark1Degree(RealSHTBenchmark): @@ -45,6 +53,13 @@ class RealSHTBenchmark1Degree(RealSHTBenchmark): def new(cls) -> "RealSHTBenchmark1Degree": return cls.new_with_shape(B=scale_batch_size(4096), H=180, L=360) +@register_benchmark("real_sht_quarter_deg") +class RealSHTBenchmarkQuarterDegree(RealSHTBenchmark): + + @classmethod + def new(cls) -> "RealSHTBenchmarkQuarterDegree": + return cls.new_with_shape(B=scale_batch_size(1), H=721, L=1440) + class InverseRealSHTBenchmark(BenchmarkABC): @@ -60,21 +75,37 @@ def new(cls) -> "InverseRealSHTBenchmark": ... @classmethod @final def new_with_shape(cls: type[Self], B: int, H: int, L: int) -> Self: - device = get_device() - x = torch.randn(B, H, L, device=device) - forward_sht = RealSHT(nlat=H, nlon=L).to(device) + cls.device = get_device() + x = torch.randn(B, H, L, device=cls.device) + forward_sht = RealSHT(nlat=H, nlon=L).to(cls.device) x_hat = forward_sht(x) - inverse_sht = InverseRealSHT(nlat=H, nlon=L).to(device) + x_hat.requires_grad = True + inverse_sht = InverseRealSHT(nlat=H, nlon=L).to(cls.device) return cls(inverse_sht=inverse_sht, x_hat=x_hat) @final - def run_instance(self, timer: Timer) -> TensorDict: + def run_instance_forward(self, timer: Timer) -> TensorDict: result = self.inverse_sht(self.x_hat) + self.output = result return {"output": result.detach()} + @final + def run_instance_backward(self, timer: Timer) -> TensorDict: + g = torch.randn_like(self.output) + self.output.backward(g, retain_graph=True) + return {"gradient": self.x_hat.grad} + +# predefined benchmarks @register_benchmark("inverse_real_sht_1deg") class InverseRealSHTBenchmark1Degree(InverseRealSHTBenchmark): @classmethod def new(cls) -> "InverseRealSHTBenchmark1Degree": return cls.new_with_shape(B=scale_batch_size(4096), H=180, L=360) + +@register_benchmark("inverse_real_sht_quarter_deg") +class InverseRealSHTBenchmarkQuarterDegree(InverseRealSHTBenchmark): + + @classmethod + def new(cls) -> "InverseRealSHTBenchmarkQuarterDegree": + return cls.new_with_shape(B=scale_batch_size(1), H=721, L=1440) diff --git a/torch_harmonics/benchmark/timer.py b/torch_harmonics/benchmark/timer.py index 44bd5190..eafd2e9e 100644 --- a/torch_harmonics/benchmark/timer.py +++ b/torch_harmonics/benchmark/timer.py @@ -10,11 +10,13 @@ class TimerResult: count: int avg_time: float - children: dict[str, "TimerResult"] + unit: str = "ms" + children: dict[str, "TimerResult"] = dataclasses.field(default_factory=dict) def get_logs(self, max_depth: int) -> dict[str, float]: logs = { "avg_time": self.avg_time, + "unit": self.unit, } if max_depth > 0: for child_name, child in self.children.items(): @@ -201,3 +203,63 @@ def result(self) -> TimerResult: __: type[Timer] = CUDATimer del __ + + +class CPUTimer: + """Wall-clock timer with the same interface as CUDATimer, for CPU benchmarks.""" + + def __init__(self): + self._children: collections.defaultdict[str, "CPUTimer"] = ( + collections.defaultdict(CPUTimer) + ) + self._event_pairs: list[CPUEventPair] = [] + self._entered = False + self._result: TimerResult | None = None + + def __enter__(self): + if self._entered: + raise RuntimeError("CPUTimer is already entered.") + self._entered = True + self._event_pairs.append(CPUEventPair()) + self._event_pairs[-1].record_start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if not self._event_pairs: + raise RuntimeError("CPUTimer context was not properly entered.") + self._event_pairs[-1].record_end() + self._entered = False + return False + + def child(self, name: str) -> "CPUTimer": + if not self._entered: + raise RuntimeError( + "CPUTimer child cannot be used before entering the timer." + ) + return self._children[name] + + @property + def _avg_time(self) -> float: + if not self._event_pairs: + raise RuntimeError( + "CPUTimer report cannot be generated before entering the timer." + ) + total_time = sum(ep.elapsed_time_ms() for ep in self._event_pairs) + return total_time / len(self._event_pairs) + + def _child_reports(self) -> dict[str, TimerResult]: + return {name: child.result for name, child in self._children.items()} + + @property + def result(self) -> TimerResult: + if self._result is None: + self._result = TimerResult( + count=len(self._event_pairs), + avg_time=self._avg_time, + children=self._child_reports(), + ) + return self._result + + +_: type[Timer] = CPUTimer +del _ From eedbbac50c405bcb0891e1008eec74872a549e19 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Mon, 23 Mar 2026 18:08:35 +0000 Subject: [PATCH 5/8] Revert separate forward/backward benchmarks; use child timers instead Replace run_instance_forward/run_instance_backward with a single run_instance that uses timer.child("forward") and timer.child("backward") to time phases within one call. This lets the backward pass reuse the forward computation and keeps the benchmark API simpler. Also fix gradient accumulation across iterations and remove unnecessary retain_graph=True. Co-Authored-By: Claude Opus 4.6 (1M context) --- torch_harmonics/benchmark/benchmark.py | 50 +++++-------------------- torch_harmonics/benchmark/disco.py | 25 ++++++------- torch_harmonics/benchmark/run.py | 24 +++++------- torch_harmonics/benchmark/sht.py | 51 ++++++++++++-------------- 4 files changed, 55 insertions(+), 95 deletions(-) diff --git a/torch_harmonics/benchmark/benchmark.py b/torch_harmonics/benchmark/benchmark.py index aa50685e..21b7beaf 100644 --- a/torch_harmonics/benchmark/benchmark.py +++ b/torch_harmonics/benchmark/benchmark.py @@ -20,19 +20,18 @@ @dataclasses.dataclass class BenchmarkResult: - phase: str device: str timer: TimerResult cpu_time: float def __repr__(self) -> str: - return f"BenchmarkResult(phase={self.phase}, device={self.device}, timer={self.timer}, cpu_time={self.cpu_time})" + return f"BenchmarkResult(device={self.device}, timer={self.timer}, cpu_time={self.cpu_time})" def asdict(self) -> dict: return dataclasses.asdict(self) def get_logs(self, max_depth: int) -> dict[str, float]: - logs = {"phase": self.phase, "device": self.device, "cpu_time": self.cpu_time} + logs = {"device": self.device, "cpu_time": self.cpu_time} logs.update(self.timer.get_logs(max_depth=max_depth)) return logs @@ -59,64 +58,33 @@ def _sync(cls) -> None: torch.cuda.synchronize() @classmethod - def run_forward_benchmark(cls, iters=10, warmup=1) -> BenchmarkResult: + def run_benchmark(cls, iters=10, warmup=1) -> BenchmarkResult: null_timer = NullTimer() benchmark = cls.new() for _ in range(warmup): - benchmark.run_instance_forward(null_timer) + benchmark.run_instance(null_timer) timer = cls._make_timer() cpu_timer = CPUEventPair() cpu_timer.record_start() for _ in range(iters): with timer: - benchmark.run_instance_forward(timer) + benchmark.run_instance(timer) cls._sync() cpu_timer.record_end() return BenchmarkResult( - phase="forward", - device=str(get_device()), - timer=timer.result, - cpu_time=cpu_timer.elapsed_time_ms(), - ) - - @classmethod - def run_backward_benchmark(cls, iters=10, warmup=1) -> BenchmarkResult: - null_timer = NullTimer() - benchmark = cls.new() - benchmark.run_instance_forward(null_timer) - for _ in range(warmup): - benchmark.run_instance_backward(null_timer) - timer = cls._make_timer() - cpu_timer = CPUEventPair() - cpu_timer.record_start() - for _ in range(iters): - with timer: - benchmark.run_instance_backward(timer) - cls._sync() - cpu_timer.record_end() - return BenchmarkResult( - phase="backward", device=str(get_device()), timer=timer.result, cpu_time=cpu_timer.elapsed_time_ms(), ) @abc.abstractmethod - def run_instance_forward(self: Self, timer: Timer) -> TensorDict: + def run_instance(self: Self, timer: Timer) -> TensorDict: """ - Run the benchmark in backward pass. This will be called multiple times, + Run the benchmark. This will be called multiple times, and should return a TensorDict of results. - This must not mutate any state on self, since the same instance may be - used across multiple iterations. - """ - pass - - @abc.abstractmethod - def run_instance_backward(self: Self, timer: Timer) -> TensorDict: - """ - Run the benchmark in forward pass. This will be called multiple times, - and should return a TensorDict of results. + Use timer.child("forward") and timer.child("backward") context + managers to separately time phases within the benchmark. This must not mutate any state on self, since the same instance may be used across multiple iterations. diff --git a/torch_harmonics/benchmark/disco.py b/torch_harmonics/benchmark/disco.py index 05814011..cfd13101 100644 --- a/torch_harmonics/benchmark/disco.py +++ b/torch_harmonics/benchmark/disco.py @@ -35,7 +35,7 @@ def new_with_shape( nlon: int, kernel_shape: int = 3, ) -> Self: - cls.device = get_device() + device = get_device() conv = DiscreteContinuousConvS2( in_channels=in_channels, out_channels=out_channels, @@ -44,21 +44,20 @@ def new_with_shape( kernel_shape=kernel_shape, theta_cutoff=None, optimized_kernel=False, - ).to(cls.device) - x = torch.randn(B, in_channels, nlat, nlon, dtype=torch.float32, device=cls.device) + ).to(device) + x = torch.randn(B, in_channels, nlat, nlon, dtype=torch.float32, device=device, requires_grad=True) return cls(conv=conv, x=x) @final - def run_instance_forward(self, timer: Timer) -> TensorDict: - result = self.conv(self.x) - self.output = result - return {"outputs": result.detach()} - - @final - def run_instance_backward(self, timer: Timer) -> TensorDict: - g = torch.randn_like(self.output) - self.output.backward(g, retain_graph=True) - return {"gradient": self.x.grad} + def run_instance(self, timer: Timer) -> TensorDict: + if self.x.grad is not None: + self.x.grad.zero_() + with timer.child("forward"): + result = self.conv(self.x) + with timer.child("backward"): + g = torch.randn_like(result) + result.backward(g) + return {"output": result.detach()} @register_benchmark("disco_conv_s2_torch_1deg") diff --git a/torch_harmonics/benchmark/run.py b/torch_harmonics/benchmark/run.py index f6f5a153..46834c59 100644 --- a/torch_harmonics/benchmark/run.py +++ b/torch_harmonics/benchmark/run.py @@ -76,26 +76,22 @@ def main( else: benchmarks_to_run = benchmarks - def get_filename(name, phase, extension) -> pathlib.Path: + def get_filename(name, extension) -> pathlib.Path: safe_name = name.replace("/", "_").replace(".", "_").lower() return ( output_dir - / f"{safe_name}_{phase}_{safe_device_name}_{get_git_commit()}.{extension}" + / f"{safe_name}_{safe_device_name}_{get_git_commit()}.{extension}" ) for name, cls in benchmarks_to_run.items(): - for run_fn, phase in [ - (cls.run_forward_benchmark, "forward"), - (cls.run_backward_benchmark, "backward"), - ]: - logging.info(f"Running {phase} benchmark: {name}") - result = run_fn(iters=iters) - result_data = json.dumps(dataclasses.asdict(result), indent=2) - logging.info(f"Result:\n{result_data}") - json_path = get_filename(name, phase, "json") - with open(json_path, "w") as f: - logging.info(f"Saving result json to {f.name}") - f.write(result_data) + logging.info(f"Running benchmark: {name}") + result = cls.run_benchmark(iters=iters) + result_data = json.dumps(dataclasses.asdict(result), indent=2) + logging.info(f"Result:\n{result_data}") + json_path = get_filename(name, "json") + with open(json_path, "w") as f: + logging.info(f"Saving result json to {f.name}") + f.write(result_data) return 0 diff --git a/torch_harmonics/benchmark/sht.py b/torch_harmonics/benchmark/sht.py index 9411fc7e..59998b56 100644 --- a/torch_harmonics/benchmark/sht.py +++ b/torch_harmonics/benchmark/sht.py @@ -27,24 +27,22 @@ def new(cls) -> "RealSHTBenchmark": ... @classmethod @final def new_with_shape(cls: type[Self], B: int, H: int, L: int) -> Self: - cls.device = get_device() - x = torch.randn(B, H, L, device=cls.device) - x.requires_grad = True - forward_sht = RealSHT(nlat=H, nlon=L).to(cls.device) + device = get_device() + x = torch.randn(B, H, L, device=device, requires_grad=True) + forward_sht = RealSHT(nlat=H, nlon=L).to(device) return cls(forward_sht=forward_sht, x=x) @final - def run_instance_forward(self, timer: Timer) -> TensorDict: - result = self.forward_sht(self.x) - self.output = result + def run_instance(self, timer: Timer) -> TensorDict: + if self.x.grad is not None: + self.x.grad.zero_() + with timer.child("forward"): + result = self.forward_sht(self.x) + with timer.child("backward"): + g = torch.randn_like(result) + result.backward(g) return {"output": result.detach()} - @final - def run_instance_backward(self, timer: Timer) -> TensorDict: - g = torch.randn_like(self.output) - self.output.backward(g, retain_graph=True) - return {"gradient": self.x.grad} - # predefined benchmarks @register_benchmark("real_sht_1deg") class RealSHTBenchmark1Degree(RealSHTBenchmark): @@ -75,26 +73,25 @@ def new(cls) -> "InverseRealSHTBenchmark": ... @classmethod @final def new_with_shape(cls: type[Self], B: int, H: int, L: int) -> Self: - cls.device = get_device() - x = torch.randn(B, H, L, device=cls.device) - forward_sht = RealSHT(nlat=H, nlon=L).to(cls.device) + device = get_device() + x = torch.randn(B, H, L, device=device) + forward_sht = RealSHT(nlat=H, nlon=L).to(device) x_hat = forward_sht(x) - x_hat.requires_grad = True - inverse_sht = InverseRealSHT(nlat=H, nlon=L).to(cls.device) + x_hat = x_hat.detach().requires_grad_(True) + inverse_sht = InverseRealSHT(nlat=H, nlon=L).to(device) return cls(inverse_sht=inverse_sht, x_hat=x_hat) @final - def run_instance_forward(self, timer: Timer) -> TensorDict: - result = self.inverse_sht(self.x_hat) - self.output = result + def run_instance(self, timer: Timer) -> TensorDict: + if self.x_hat.grad is not None: + self.x_hat.grad.zero_() + with timer.child("forward"): + result = self.inverse_sht(self.x_hat) + with timer.child("backward"): + g = torch.randn_like(result) + result.backward(g) return {"output": result.detach()} - @final - def run_instance_backward(self, timer: Timer) -> TensorDict: - g = torch.randn_like(self.output) - self.output.backward(g, retain_graph=True) - return {"gradient": self.x_hat.grad} - # predefined benchmarks @register_benchmark("inverse_real_sht_1deg") class InverseRealSHTBenchmark1Degree(InverseRealSHTBenchmark): From fe217819ce2081b1752972d07c2d94d72cd9c643 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Mon, 23 Mar 2026 18:21:46 +0000 Subject: [PATCH 6/8] Add context manager support to CPUEventPair and use it in run_benchmark Address review feedback to use `with cpu_timer:` instead of explicit record_start/record_end calls. Co-Authored-By: Claude Opus 4.6 (1M context) --- torch_harmonics/benchmark/benchmark.py | 12 +++++------- torch_harmonics/benchmark/timer.py | 8 ++++++++ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/torch_harmonics/benchmark/benchmark.py b/torch_harmonics/benchmark/benchmark.py index 21b7beaf..7029dde5 100644 --- a/torch_harmonics/benchmark/benchmark.py +++ b/torch_harmonics/benchmark/benchmark.py @@ -64,13 +64,11 @@ def run_benchmark(cls, iters=10, warmup=1) -> BenchmarkResult: for _ in range(warmup): benchmark.run_instance(null_timer) timer = cls._make_timer() - cpu_timer = CPUEventPair() - cpu_timer.record_start() - for _ in range(iters): - with timer: - benchmark.run_instance(timer) - cls._sync() - cpu_timer.record_end() + with CPUEventPair() as cpu_timer: + for _ in range(iters): + with timer: + benchmark.run_instance(timer) + cls._sync() return BenchmarkResult( device=str(get_device()), timer=timer.result, diff --git a/torch_harmonics/benchmark/timer.py b/torch_harmonics/benchmark/timer.py index eafd2e9e..6b8b9e46 100644 --- a/torch_harmonics/benchmark/timer.py +++ b/torch_harmonics/benchmark/timer.py @@ -110,6 +110,14 @@ def __init__(self): self.start_time = None self.end_time = None + def __enter__(self): + self.record_start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.record_end() + return False + def record_start(self): if self.start_time is not None: raise RuntimeError( From 04ef6368b8abffa014c40ccd5f4095cbbcf79483 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Mon, 23 Mar 2026 18:23:16 +0000 Subject: [PATCH 7/8] Remove batch_size override feature from benchmark CLI MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The batch size override breaks reproducibility — output files can't distinguish whether timing differences across runs come from code changes or different CLI arguments. Each benchmark should define its own appropriate batch size via hardware scaling. Co-Authored-By: Claude Opus 4.6 (1M context) --- torch_harmonics/benchmark/hardware.py | 15 ++------------- torch_harmonics/benchmark/run.py | 13 +------------ 2 files changed, 3 insertions(+), 25 deletions(-) diff --git a/torch_harmonics/benchmark/hardware.py b/torch_harmonics/benchmark/hardware.py index 6abf0d92..684974fb 100644 --- a/torch_harmonics/benchmark/hardware.py +++ b/torch_harmonics/benchmark/hardware.py @@ -16,13 +16,6 @@ _DEFAULT_BATCH_SIZE_FACTOR = 1.0 _device: torch.device | None = None -_batch_size_override: int | None = None - - -def set_batch_size(batch_size: int) -> None: - """Override the batch size used by all benchmarks, bypassing hardware scaling.""" - global _batch_size_override - _batch_size_override = batch_size def set_device(device: str | torch.device) -> None: @@ -64,12 +57,8 @@ def get_batch_size_factor() -> float: def scale_batch_size(base: int) -> int: - """Return the batch size to use for a benchmark. + """Scale a base batch size (tuned for Tesla T4) by the hardware factor. - If a global override has been set via set_batch_size(), that value is - returned directly. Otherwise the base is scaled by the hardware factor - (tuned relative to a Tesla T4). Always returns at least 1. + Always returns at least 1. """ - if _batch_size_override is not None: - return max(1, _batch_size_override) return max(1, round(base * get_batch_size_factor())) diff --git a/torch_harmonics/benchmark/run.py b/torch_harmonics/benchmark/run.py index 46834c59..37dd1995 100644 --- a/torch_harmonics/benchmark/run.py +++ b/torch_harmonics/benchmark/run.py @@ -9,7 +9,7 @@ import torch from torch_harmonics.benchmark.benchmark import get_benchmarks -from torch_harmonics.benchmark.hardware import set_batch_size, set_device +from torch_harmonics.benchmark.hardware import set_device _GIT_COMMIT: str | None = None @@ -54,11 +54,8 @@ def main( iters: int, output_dir: pathlib.Path, device: str, - batch_size: int | None, ) -> int: set_device(device) - if batch_size is not None: - set_batch_size(batch_size) output_dir.mkdir(parents=True, exist_ok=True) device_name = get_device_name() safe_device_name = device_name.replace(" ", "_").replace("/", "_").lower() @@ -129,20 +126,12 @@ def cli() -> int: help="Device to run benchmarks on, e.g. 'cpu', 'cuda', 'cuda:1'. " "Defaults to 'cuda' if available, otherwise 'cpu'.", ) - parser.add_argument( - "--batch-size", - type=int, - default=None, - help="Override the batch size for all benchmarks. If not set, each " - "benchmark uses its hardware-scaled default.", - ) args = parser.parse_args() return main( benchmark_name=args.name, iters=args.iters, output_dir=pathlib.Path(args.output_dir), device=args.device, - batch_size=args.batch_size, ) From 5036e6c12a0e82082dcf305ff16901744abce7f1 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Tue, 24 Mar 2026 21:32:29 +0000 Subject: [PATCH 8/8] Add batch size factor support for scaling benchmarks Move hardware-specific batch size scaling out of scale_batch_size and into run.py, so individual benchmarks no longer implicitly read global state. Each benchmark now runs with multiple user-configurable batch size factors (default 1x and 2x), with the factor included in the output filename. The --batch-size-factors CLI arg controls which factors to use. Co-Authored-By: Claude Opus 4.6 (1M context) --- torch_harmonics/benchmark/benchmark.py | 11 ++++++-- torch_harmonics/benchmark/disco.py | 6 ++-- torch_harmonics/benchmark/hardware.py | 6 ++-- torch_harmonics/benchmark/run.py | 39 ++++++++++++++++++-------- torch_harmonics/benchmark/sht.py | 20 ++++++------- 5 files changed, 52 insertions(+), 30 deletions(-) diff --git a/torch_harmonics/benchmark/benchmark.py b/torch_harmonics/benchmark/benchmark.py index 7029dde5..7944fbda 100644 --- a/torch_harmonics/benchmark/benchmark.py +++ b/torch_harmonics/benchmark/benchmark.py @@ -39,10 +39,15 @@ def get_logs(self, max_depth: int) -> dict[str, float]: class BenchmarkABC(abc.ABC): @classmethod @abc.abstractmethod - def new(cls: type[Self]) -> Self: + def new(cls: type[Self], batch_size_factor: float = 1.0) -> Self: """ Initialize any state needed for the benchmark. This will be called once before the benchmark is run. + + Args: + batch_size_factor: Combined hardware and user-requested scale factor + to apply to the base batch size. Each benchmark defines its own + base batch size and multiplies it by this factor. """ pass @@ -58,9 +63,9 @@ def _sync(cls) -> None: torch.cuda.synchronize() @classmethod - def run_benchmark(cls, iters=10, warmup=1) -> BenchmarkResult: + def run_benchmark(cls, iters=10, warmup=1, batch_size_factor: float = 1.0) -> BenchmarkResult: null_timer = NullTimer() - benchmark = cls.new() + benchmark = cls.new(batch_size_factor=batch_size_factor) for _ in range(warmup): benchmark.run_instance(null_timer) timer = cls._make_timer() diff --git a/torch_harmonics/benchmark/disco.py b/torch_harmonics/benchmark/disco.py index cfd13101..dfe12ec5 100644 --- a/torch_harmonics/benchmark/disco.py +++ b/torch_harmonics/benchmark/disco.py @@ -22,7 +22,7 @@ def __init__(self, conv: DiscreteContinuousConvS2, x: torch.Tensor): @classmethod @abc.abstractmethod - def new(cls) -> "DiscreteContinuousConvS2Benchmark": ... + def new(cls, batch_size_factor: float = 1.0) -> "DiscreteContinuousConvS2Benchmark": ... @classmethod @final @@ -64,7 +64,7 @@ def run_instance(self, timer: Timer) -> TensorDict: class DiscreteContinuousConvS2TorchBenchmark1Degree(DiscreteContinuousConvS2Benchmark): @classmethod - def new(cls) -> "DiscreteContinuousConvS2TorchBenchmark1Degree": + def new(cls, batch_size_factor: float = 1.0) -> "DiscreteContinuousConvS2TorchBenchmark1Degree": return cls.new_with_shape( - B=scale_batch_size(4), in_channels=4, out_channels=4, nlat=180, nlon=360, + B=scale_batch_size(4, batch_size_factor), in_channels=4, out_channels=4, nlat=180, nlon=360, ) diff --git a/torch_harmonics/benchmark/hardware.py b/torch_harmonics/benchmark/hardware.py index 684974fb..66b36d77 100644 --- a/torch_harmonics/benchmark/hardware.py +++ b/torch_harmonics/benchmark/hardware.py @@ -56,9 +56,9 @@ def get_batch_size_factor() -> float: return factor -def scale_batch_size(base: int) -> int: - """Scale a base batch size (tuned for Tesla T4) by the hardware factor. +def scale_batch_size(base: int, factor: float) -> int: + """Scale a base batch size by the given factor. Always returns at least 1. """ - return max(1, round(base * get_batch_size_factor())) + return max(1, round(base * factor)) diff --git a/torch_harmonics/benchmark/run.py b/torch_harmonics/benchmark/run.py index 37dd1995..8caefd07 100644 --- a/torch_harmonics/benchmark/run.py +++ b/torch_harmonics/benchmark/run.py @@ -9,7 +9,7 @@ import torch from torch_harmonics.benchmark.benchmark import get_benchmarks -from torch_harmonics.benchmark.hardware import set_device +from torch_harmonics.benchmark.hardware import get_batch_size_factor, set_device _GIT_COMMIT: str | None = None @@ -54,13 +54,16 @@ def main( iters: int, output_dir: pathlib.Path, device: str, + batch_size_factors: list[float], ) -> int: set_device(device) output_dir.mkdir(parents=True, exist_ok=True) device_name = get_device_name() safe_device_name = device_name.replace(" ", "_").replace("/", "_").lower() + hardware_factor = get_batch_size_factor() logging.info(f"Running benchmarks on device: {device_name}") + logging.info(f"Hardware batch size factor: {hardware_factor}") benchmarks = get_benchmarks() if benchmark_name is not None: if benchmark_name not in benchmarks: @@ -73,22 +76,25 @@ def main( else: benchmarks_to_run = benchmarks - def get_filename(name, extension) -> pathlib.Path: + def get_filename(name, factor, extension) -> pathlib.Path: safe_name = name.replace("/", "_").replace(".", "_").lower() + factor_str = f"{factor:g}x" return ( output_dir - / f"{safe_name}_{safe_device_name}_{get_git_commit()}.{extension}" + / f"{safe_name}_{factor_str}_{safe_device_name}_{get_git_commit()}.{extension}" ) for name, cls in benchmarks_to_run.items(): - logging.info(f"Running benchmark: {name}") - result = cls.run_benchmark(iters=iters) - result_data = json.dumps(dataclasses.asdict(result), indent=2) - logging.info(f"Result:\n{result_data}") - json_path = get_filename(name, "json") - with open(json_path, "w") as f: - logging.info(f"Saving result json to {f.name}") - f.write(result_data) + for factor in batch_size_factors: + combined_factor = hardware_factor * factor + logging.info(f"Running benchmark: {name} (batch size factor: {factor:g}x)") + result = cls.run_benchmark(iters=iters, batch_size_factor=combined_factor) + result_data = json.dumps(dataclasses.asdict(result), indent=2) + logging.info(f"Result:\n{result_data}") + json_path = get_filename(name, factor, "json") + with open(json_path, "w") as f: + logging.info(f"Saving result json to {f.name}") + f.write(result_data) return 0 @@ -126,12 +132,23 @@ def cli() -> int: help="Device to run benchmarks on, e.g. 'cpu', 'cuda', 'cuda:1'. " "Defaults to 'cuda' if available, otherwise 'cpu'.", ) + parser.add_argument( + "--batch-size-factors", + type=float, + nargs="+", + default=[1, 2], + help="Batch size scale factors to run each benchmark with. " + "Each benchmark will be run once per factor. " + "These are multiplied with the hardware-specific factor. " + "Defaults to [1, 2].", + ) args = parser.parse_args() return main( benchmark_name=args.name, iters=args.iters, output_dir=pathlib.Path(args.output_dir), device=args.device, + batch_size_factors=args.batch_size_factors, ) diff --git a/torch_harmonics/benchmark/sht.py b/torch_harmonics/benchmark/sht.py index 59998b56..406f4673 100644 --- a/torch_harmonics/benchmark/sht.py +++ b/torch_harmonics/benchmark/sht.py @@ -22,7 +22,7 @@ def __init__(self, forward_sht: RealSHT, x: torch.Tensor): @classmethod @abc.abstractmethod - def new(cls) -> "RealSHTBenchmark": ... + def new(cls, batch_size_factor: float = 1.0) -> "RealSHTBenchmark": ... @classmethod @final @@ -48,15 +48,15 @@ def run_instance(self, timer: Timer) -> TensorDict: class RealSHTBenchmark1Degree(RealSHTBenchmark): @classmethod - def new(cls) -> "RealSHTBenchmark1Degree": - return cls.new_with_shape(B=scale_batch_size(4096), H=180, L=360) + def new(cls, batch_size_factor: float = 1.0) -> "RealSHTBenchmark1Degree": + return cls.new_with_shape(B=scale_batch_size(4096, batch_size_factor), H=180, L=360) @register_benchmark("real_sht_quarter_deg") class RealSHTBenchmarkQuarterDegree(RealSHTBenchmark): @classmethod - def new(cls) -> "RealSHTBenchmarkQuarterDegree": - return cls.new_with_shape(B=scale_batch_size(1), H=721, L=1440) + def new(cls, batch_size_factor: float = 1.0) -> "RealSHTBenchmarkQuarterDegree": + return cls.new_with_shape(B=scale_batch_size(1, batch_size_factor), H=721, L=1440) class InverseRealSHTBenchmark(BenchmarkABC): @@ -68,7 +68,7 @@ def __init__(self, inverse_sht: InverseRealSHT, x_hat: torch.Tensor): @classmethod @abc.abstractmethod - def new(cls) -> "InverseRealSHTBenchmark": ... + def new(cls, batch_size_factor: float = 1.0) -> "InverseRealSHTBenchmark": ... @classmethod @final @@ -97,12 +97,12 @@ def run_instance(self, timer: Timer) -> TensorDict: class InverseRealSHTBenchmark1Degree(InverseRealSHTBenchmark): @classmethod - def new(cls) -> "InverseRealSHTBenchmark1Degree": - return cls.new_with_shape(B=scale_batch_size(4096), H=180, L=360) + def new(cls, batch_size_factor: float = 1.0) -> "InverseRealSHTBenchmark1Degree": + return cls.new_with_shape(B=scale_batch_size(4096, batch_size_factor), H=180, L=360) @register_benchmark("inverse_real_sht_quarter_deg") class InverseRealSHTBenchmarkQuarterDegree(InverseRealSHTBenchmark): @classmethod - def new(cls) -> "InverseRealSHTBenchmarkQuarterDegree": - return cls.new_with_shape(B=scale_batch_size(1), H=721, L=1440) + def new(cls, batch_size_factor: float = 1.0) -> "InverseRealSHTBenchmarkQuarterDegree": + return cls.new_with_shape(B=scale_batch_size(1, batch_size_factor), H=721, L=1440)