
Introduce benchmark framework using CUDA events #157

Open
mcgibbon wants to merge 5 commits into NVIDIA:main from mcgibbon:feature/benchmark-framework

Conversation

@mcgibbon (Contributor) commented Mar 12, 2026

This PR adds timing for the SHT and for the torch implementation of DISCO convolution through a new benchmarking framework, run via `python -m torch_harmonics.benchmark`.

This is largely taken from the implementation we used/I authored in https://github.com/ai2cm/ace
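The core loop of such a framework, as a minimal sketch under assumed names: warmup iterations are discarded, then each timed iteration is recorded individually. This sketch times on the CPU with `perf_counter`; the actual framework, per the PR title, records `torch.cuda.Event(enable_timing=True)` pairs and synchronizes before reading elapsed times.

```python
import time
import statistics

def run_benchmark(fn, iters=10, warmup=1):
    """Hypothetical sketch of a benchmark loop. On GPU, the real
    framework would use torch.cuda.Event pairs plus
    torch.cuda.synchronize() instead of perf_counter."""
    for _ in range(warmup):
        fn()  # warmup iterations are not timed
    times_ms = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        times_ms.append((time.perf_counter() - start) * 1000.0)
    return {"median_ms": statistics.median(times_ms), "iters": iters}

result = run_benchmark(lambda: sum(i * i for i in range(10_000)), iters=5)
```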

mcgibbon and others added 3 commits March 12, 2026 15:30
Introduce a torch_harmonics.benchmark subpackage with:
- Timer infrastructure (CUDATimer, NullTimer, CPUEventPair) for GPU
  event-based and CPU wall-clock timing
- BenchmarkABC base class with registry via @register_benchmark
- CLI runner (python -m torch_harmonics.benchmark) that saves JSON results
- RealSHT and InverseRealSHT benchmarks at 1-degree resolution

Also add benchmark_results to .gitignore.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Register a disco_conv_s2_torch_1deg benchmark at 1-degree resolution
(B=4, 4 channels, 180x360) using the non-optimized torch contraction
path, which does not require the custom CUDA extension.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Introduce hardware.py with a device-name-to-scale-factor lookup table
so benchmark batch sizes adapt to different GPUs. Base batch sizes are
tuned for Tesla T4 (factor 1.0). Unknown devices default to 1.0 with
a warning to add an entry for their hardware.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
@mcgibbon (Contributor, Author) commented:

@azrael417 you may find this helpful to check the SHT timings on your hardware, for #155. You'll want to insert new batch size scaling factors to fully occupy the hardware. I tried to make it straightforward to add new benchmarks.

The entrypoint will create git-tag-labelled JSON files under `benchmark_results/` in the directory you run it from (location modifiable by flag).

@azrael417 (Collaborator) commented:

Hello Jeremy, thanks for putting this together.
I have added multiple things to the MR. This is what I added:

  • backward benchmark
  • device selection support
  • batch size override

Can you please have a look and see if you are OK with it?

@mcgibbon (Contributor, Author) commented:

> Hello Jeremy, thanks for putting this together. I have added multiple things to the MR. This is what I added:
>
>   • backward benchmark
>   • device selection support
>   • batch size override
>
> Can you please have a look and see if you are OK with it?

Thanks @azrael417 . I don't see these commits in the history on this PR, can you link me to where I can check them out?

@azrael417 (Collaborator) commented Mar 20, 2026

Oh, I seemingly cannot push them automatically to your branch. Those ended up in branch pr-157. Feel free to review and potentially merge those into this PR.

torch.cuda.synchronize()

@classmethod
def run_forward_benchmark(cls, iters=10, warmup=1) -> BenchmarkResult:
@mcgibbon (Contributor, Author) commented:

I don't think we should separate out forward and backward benchmarks like this. Rather, run_instance is free to define a timer.context("forward") and timer.context("backward") block as separate blocks if it so chooses, without being required to do so. That way the backward benchmark can also take advantage of the work from the forward benchmark, instead of repeating it.

I'll refactor the existing benchmarks so they time the backward pass, and remove the "backward" framework infrastructure.
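The structure proposed here, with hypothetical names, a sketch only (the later commit in this PR calls the method `timer.child` rather than `timer.context`), could look like this: the backward block reuses the output of the forward block instead of recomputing it.

```python
import time
from contextlib import contextmanager

class Timer:
    """Hypothetical phase timer: run_instance opens named sub-blocks,
    and each block's wall-clock time is recorded under its name."""
    def __init__(self):
        self.times_ms = {}

    @contextmanager
    def context(self, name):
        start = time.perf_counter()
        yield
        self.times_ms[name] = (time.perf_counter() - start) * 1000.0

def run_instance(timer):
    with timer.context("forward"):
        out = [x * x for x in range(1000)]  # stand-in for the forward pass
    with timer.context("backward"):
        grad = sum(out)                     # reuses `out` from the forward block
    return grad
```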

@mcgibbon (Contributor, Author) commented:

Done.

@azrael417 (Collaborator) commented:

I added this so that you can run forward and backward independently. For example, when we implement a new kernel, we first implement and optimize the forward pass. In that case there is no backward pass defined yet, and we do not want to run it.

Replace run_instance_forward/run_instance_backward with a single
run_instance that uses timer.child("forward") and timer.child("backward")
to time phases within one call. This lets the backward pass reuse the
forward computation and keeps the benchmark API simpler. Also fix
gradient accumulation across iterations and remove unnecessary
retain_graph=True.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
@mcgibbon (Contributor, Author) commented Mar 23, 2026

What is the use case for the batch size override? I was hoping the GPU-dependent factors would handle this, and was thinking the benchmark code should set its problem size.

I'm a little worried that CLI-set batch sizes will result in different benchmark runs/output files using different batch sizes, which doesn't show up in the filename or in the result. That means we can no longer be confident the output directory contains directly comparable benchmarks. At least when batch sizes get changed by modifying the benchmark code, this is reflected in the git sha (or the -dirty suffix) changing.

Comment on lines +69 to +71
If a global override has been set via set_batch_size(), that value is
returned directly. Otherwise the base is scaled by the hardware factor
(tuned relative to a Tesla T4). Always returns at least 1.
@mcgibbon (Contributor, Author) commented:

This is a bit counter-intuitive, I wouldn't expect a helper function scale_batch_size to access globals or do this kind of behavior. We should override at a higher level in the code where it's more appropriate.

Comment on lines +67 to +68
cpu_timer = CPUEventPair()
cpu_timer.record_start()
@mcgibbon (Contributor, Author) commented:

Probably we should use with cpu_timer: instead of record_start/record_end.
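The suggestion can be satisfied by making the pair a context manager; a minimal sketch with assumed internals (the real `CPUEventPair` may differ):

```python
import time

class CPUEventPair:
    """Hypothetical sketch: record_start/record_end kept for
    compatibility, with __enter__/__exit__ added so the pair
    can be used as `with cpu_timer:` as suggested in the review."""
    def record_start(self):
        self._start = time.perf_counter()
    def record_end(self):
        self._end = time.perf_counter()
    def elapsed_ms(self):
        return (self._end - self._start) * 1000.0
    def __enter__(self):
        self.record_start()
        return self
    def __exit__(self, exc_type, exc, tb):
        self.record_end()
        return False  # never swallow exceptions raised in the block

cpu_timer = CPUEventPair()
with cpu_timer:
    sum(range(100_000))
```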

) -> int:
    set_device(device)
    if batch_size is not None:
        set_batch_size(batch_size)
@mcgibbon (Contributor, Author) commented:
We can avoid a lot of indirection if we explicitly pass cls.run_benchmark(iters=iters, batch_size=batch_size) with a default batch_size=None on that method. Let's refactor to do that, and delete this global state.
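The suggested refactor, passing the override as an explicit argument rather than stashing it in module-level state, might be sketched as follows (class and attribute names assumed for illustration):

```python
# Hypothetical sketch: batch_size flows as an explicit parameter with a
# None default, instead of being set globally via set_batch_size().
class Benchmark:
    BASE_BATCH_SIZE = 4

    @classmethod
    def run_benchmark(cls, iters=10, batch_size=None):
        if batch_size is None:
            # real code would scale BASE_BATCH_SIZE by the hardware factor
            batch_size = cls.BASE_BATCH_SIZE
        return {"iters": iters, "batch_size": batch_size}
```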

)

@abc.abstractmethod
def run_instance(self: Self, timer: Timer) -> TensorDict:
@mcgibbon (Contributor, Author) commented Mar 23, 2026:

This would also need to be modified to take in an optional batch_size. I'm wondering now though what the use case is for the batch size needing to be changed in the CLI rather than changing the benchmark. Overriding the batch size breaks the promise the output file makes that it's giving the timings for "this benchmark". e.g. if a user runs two executions on different git shas with different batch size arguments, the timings could be different even if the code is the same, and there's no way in the output files to tell whether this is because of random machine noise or because two different args were passed. This situation is a lot worse when some of the code did change.

Clearly batch size needs to be tuned on different hardware, but that's why the hardware-specific scaling exists, and the hardware is included in the benchmark filename.

@azrael417 (Collaborator) commented:

I disagree with maxing out GPU utilization, because in a realistic case you do not do that memory-wise. Therefore it is good to understand how a kernel scales with batch size. Having a new config entry just for a different batch size is very cumbersome.

@mcgibbon (Contributor, Author) commented Mar 23, 2026

For now I'm going to remove the batch_size override from this PR, because of the concerns I have about it. It breaks the main purpose of this code, which is to compare performance across different git shas as the code evolves. But if there's a specific use case you need it for, I can add it back in or revert the commit removing it.

Ready for another look @azrael417 . I will be out for a little over 2 weeks starting this weekend.
