Introduce benchmark framework using CUDA events #157
Open: mcgibbon wants to merge 8 commits into NVIDIA:main from mcgibbon:feature/benchmark-framework
Commits:
- 380ebd9 (mcgibbon): Add benchmark framework with registry pattern and SHT benchmarks
- ad511f1 (mcgibbon): Add DiscreteContinuousConvS2 benchmark using torch sparse path
- 7247eb7 (mcgibbon): Add hardware-dependent batch size scaling for benchmarks
- 71eaefd (azrael417): adding batch size and device ovveride to CLA
- eedbbac (mcgibbon): Revert separate forward/backward benchmarks; use child timers instead
- fe21781 (mcgibbon): Add context manager support to CPUEventPair and use it in run_benchmark
- 04ef636 (mcgibbon): Remove batch_size override feature from benchmark CLI
- 5036e6c (mcgibbon): Add batch size factor support for scaling benchmarks
.gitignore:

```diff
@@ -1,4 +1,5 @@
 *.DS_Store
 __pycache__
 *.so
 checkpoints
+*benchmark_results
```
torch_harmonics/benchmark/__init__.py (new file):

```python
from torch_harmonics.benchmark.benchmark import (
    BenchmarkABC,
    BenchmarkResult,
    get_benchmarks,
    register_benchmark,
)
from torch_harmonics.benchmark.timer import (
    CPUTimer,
    CUDATimer,
    NullTimer,
    Timer,
    TimerResult,
)

# Import to trigger registration of built-in benchmarks.
import torch_harmonics.benchmark.sht  # noqa: F401
import torch_harmonics.benchmark.disco  # noqa: F401
```
torch_harmonics/benchmark/__main__.py (new file):

```python
import sys

import torch_harmonics.benchmark  # noqa: F401 — triggers benchmark registration
from torch_harmonics.benchmark.run import cli

sys.exit(cli())
```
torch_harmonics/benchmark/benchmark.py (new file):

```python
import abc
import dataclasses
from collections.abc import Callable
from typing import Self

import torch

from torch_harmonics.benchmark.hardware import get_device
from torch_harmonics.benchmark.timer import (
    CPUEventPair,
    CPUTimer,
    CUDATimer,
    NullTimer,
    Timer,
    TimerResult,
)

TensorDict = dict[str, torch.Tensor]


@dataclasses.dataclass
class BenchmarkResult:
    device: str
    timer: TimerResult
    cpu_time: float

    def __repr__(self) -> str:
        return f"BenchmarkResult(device={self.device}, timer={self.timer}, cpu_time={self.cpu_time})"

    def asdict(self) -> dict:
        return dataclasses.asdict(self)

    def get_logs(self, max_depth: int) -> dict[str, float]:
        logs = {"device": self.device, "cpu_time": self.cpu_time}
        logs.update(self.timer.get_logs(max_depth=max_depth))
        return logs


class BenchmarkABC(abc.ABC):
    @classmethod
    @abc.abstractmethod
    def new(cls: type[Self], batch_size_factor: float = 1.0) -> Self:
        """
        Initialize any state needed for the benchmark.
        This will be called once before the benchmark is run.

        Args:
            batch_size_factor: Combined hardware and user-requested scale factor
                to apply to the base batch size. Each benchmark defines its own
                base batch size and multiplies it by this factor.
        """
        pass

    @classmethod
    def _make_timer(cls) -> CUDATimer | CPUTimer:
        if torch.cuda.is_available():
            return CUDATimer()
        return CPUTimer()

    @classmethod
    def _sync(cls) -> None:
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    @classmethod
    def run_benchmark(cls, iters=10, warmup=1, batch_size_factor: float = 1.0) -> BenchmarkResult:
        null_timer = NullTimer()
        benchmark = cls.new(batch_size_factor=batch_size_factor)
        for _ in range(warmup):
            benchmark.run_instance(null_timer)
        timer = cls._make_timer()
        with CPUEventPair() as cpu_timer:
            for _ in range(iters):
                with timer:
                    benchmark.run_instance(timer)
            cls._sync()
        return BenchmarkResult(
            device=str(get_device()),
            timer=timer.result,
            cpu_time=cpu_timer.elapsed_time_ms(),
        )

    @abc.abstractmethod
    def run_instance(self: Self, timer: Timer) -> TensorDict:
        """
        Run the benchmark. This will be called multiple times,
        and should return a TensorDict of results.

        Use timer.child("forward") and timer.child("backward") context
        managers to separately time phases within the benchmark.

        This must not mutate any state on self, since the same instance may be
        used across multiple iterations.
        """
        pass


_BENCHMARKS: dict[str, type[BenchmarkABC]] = {}


def register_benchmark(name: str) -> Callable[[type[BenchmarkABC]], type[BenchmarkABC]]:
    def _register(fn: type[BenchmarkABC]) -> type[BenchmarkABC]:
        if name in _BENCHMARKS:
            raise ValueError(f"Benchmark with name '{name}' is already registered.")
        _BENCHMARKS[name] = fn
        return fn

    return _register


def get_benchmarks() -> dict[str, type[BenchmarkABC]]:
    return _BENCHMARKS.copy()
```
torch_harmonics/benchmark/disco.py (new file):

```python
import abc
from typing import Self, final

import torch

from torch_harmonics.benchmark.benchmark import (
    BenchmarkABC,
    TensorDict,
    register_benchmark,
)
from torch_harmonics.benchmark.hardware import get_device, scale_batch_size
from torch_harmonics.benchmark.timer import Timer
from torch_harmonics.disco import DiscreteContinuousConvS2


class DiscreteContinuousConvS2Benchmark(BenchmarkABC):

    @final
    def __init__(self, conv: DiscreteContinuousConvS2, x: torch.Tensor):
        self.conv = conv
        self.x = x

    @classmethod
    @abc.abstractmethod
    def new(cls, batch_size_factor: float = 1.0) -> "DiscreteContinuousConvS2Benchmark": ...

    @classmethod
    @final
    def new_with_shape(
        cls: type[Self],
        B: int,
        in_channels: int,
        out_channels: int,
        nlat: int,
        nlon: int,
        kernel_shape: int = 3,
    ) -> Self:
        device = get_device()
        conv = DiscreteContinuousConvS2(
            in_channels=in_channels,
            out_channels=out_channels,
            in_shape=(nlat, nlon),
            out_shape=(nlat, nlon),
            kernel_shape=kernel_shape,
            theta_cutoff=None,
            optimized_kernel=False,
        ).to(device)
        x = torch.randn(B, in_channels, nlat, nlon, dtype=torch.float32, device=device, requires_grad=True)
        return cls(conv=conv, x=x)

    @final
    def run_instance(self, timer: Timer) -> TensorDict:
        if self.x.grad is not None:
            self.x.grad.zero_()
        with timer.child("forward"):
            result = self.conv(self.x)
        with timer.child("backward"):
            g = torch.randn_like(result)
            result.backward(g)
        return {"output": result.detach()}


@register_benchmark("disco_conv_s2_torch_1deg")
class DiscreteContinuousConvS2TorchBenchmark1Degree(DiscreteContinuousConvS2Benchmark):

    @classmethod
    def new(cls, batch_size_factor: float = 1.0) -> "DiscreteContinuousConvS2TorchBenchmark1Degree":
        return cls.new_with_shape(
            B=scale_batch_size(4, batch_size_factor), in_channels=4, out_channels=4, nlat=180, nlon=360,
        )
```
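`run_instance` above relies on the `Timer.child(...)` context managers to attribute time to the forward and backward phases of a single iteration. A hypothetical minimal stand-in (wall-clock only; the PR's actual `Timer` is assumed to be richer, with CUDA events, nesting, and a `TimerResult`) illustrates the usage pattern:

```python
import time
from contextlib import contextmanager


# Hypothetical stand-in for the Timer.child(...) API: each named child
# accumulates wall-clock time so phases of one iteration ("forward" vs
# "backward") can be reported separately. This only sketches the pattern,
# not the PR's implementation.
class SimpleTimer:
    def __init__(self) -> None:
        self.totals: dict[str, float] = {}

    @contextmanager
    def child(self, name: str):
        start = time.perf_counter()
        try:
            yield
        finally:
            self.totals[name] = self.totals.get(name, 0.0) + time.perf_counter() - start


timer = SimpleTimer()
with timer.child("forward"):
    sum(range(10_000))  # stand-in for the conv forward pass
with timer.child("backward"):
    sum(range(10_000))  # stand-in for the backward pass
assert set(timer.totals) == {"forward", "backward"}
```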
torch_harmonics/benchmark/hardware.py (new file):

```python
import logging

import torch

logger = logging.getLogger(__name__)

# Batch size scale factors relative to Tesla T4 (the default baseline).
# To add a new GPU, add an entry mapping its device name (as returned by
# torch.cuda.get_device_properties(...).name) to a float scale factor.
# Values > 1.0 mean the GPU is faster than a T4 and can use larger batches;
# values < 1.0 mean it is slower.
_BATCH_SIZE_FACTORS: dict[str, float] = {
    "Tesla T4": 1.0,
}

_DEFAULT_BATCH_SIZE_FACTOR = 1.0

_device: torch.device | None = None


def set_device(device: str | torch.device) -> None:
    """Override the device used by all benchmarks."""
    global _device
    _device = torch.device(device)


def get_device() -> torch.device:
    if _device is not None:
        return _device
    if torch.cuda.is_available():
        return torch.device("cuda", torch.cuda.current_device())
    return torch.device("cpu")


def get_batch_size_factor() -> float:
    """Return a hardware-dependent scale factor for benchmark batch sizes.

    Benchmarks define a base batch size tuned for a Tesla T4. This function
    returns a multiplier so that benchmarks take a similar wall-clock time
    on other hardware. If the batch size is too small, the GPU will not be
    fully occupied, and the benchmarks cannot be used to tune performance.

    Unknown devices fall back to the T4 default (1.0).
    """
    if not torch.cuda.is_available():
        return _DEFAULT_BATCH_SIZE_FACTOR
    name = torch.cuda.get_device_properties(torch.cuda.current_device()).name
    factor = _BATCH_SIZE_FACTORS.get(name)
    if factor is None:
        logger.warning(
            f"Unknown GPU '{name}', using default batch size factor "
            f"{_DEFAULT_BATCH_SIZE_FACTOR}. Add an entry to "
            f"_BATCH_SIZE_FACTORS in hardware.py to tune for this device."
        )
        return _DEFAULT_BATCH_SIZE_FACTOR
    return factor


def scale_batch_size(base: int, factor: float) -> int:
    """Scale a base batch size by the given factor.

    Always returns at least 1.
    """
    return max(1, round(base * factor))
```
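The clamping in `scale_batch_size` is worth spelling out: it rounds to the nearest integer but never returns zero, so a small base batch size survives an aggressive down-scaling factor. A quick standalone check, with the one-line body copied from above:

```python
# Copied from hardware.py above: round to the nearest integer, clamp to >= 1.
def scale_batch_size(base: int, factor: float) -> int:
    return max(1, round(base * factor))


assert scale_batch_size(4, 1.0) == 4
assert scale_batch_size(4, 2.0) == 8
assert scale_batch_size(4, 0.1) == 1  # round(0.4) == 0, clamped to 1
assert scale_batch_size(4, 1.6) == 6  # round(6.4) == 6
```

Note that Python's `round` uses banker's rounding, so exact `.5` products round to the nearest even integer.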
This would also need to be modified to take in an optional batch_size. I'm wondering now though what the use case is for the batch size needing to be changed in the CLI rather than changing the benchmark. Overriding the batch size breaks the promise the output file makes that it's giving the timings for "this benchmark". e.g. if a user runs two executions on different git shas with different batch size arguments, the timings could be different even if the code is the same, and there's no way in the output files to tell whether this is because of random machine noise or because two different args were passed. This situation is a lot worse when some of the code did change.
Clearly batch size needs to be tuned on different hardware, but that's why the hardware-specific scaling exists, and the hardware is included in the benchmark filename.
I disagree with making out GPU utilization because in a realistic case you do not do that memory wise. Therefore, it is good to understand how a kernel scales with batch size. Having a new config entry just for a different batch size is very cumbersome.
I don't quite understand this sentence, but the next one makes sense and I think might be saying the same thing.
That's a great use case, thanks for explaining it. I think it would make sense for the benchmark to automatically run with more than one batch size, perhaps a base value and then double that value. I'll update the code to do that, in a way that the framework does it for each benchmark.
For now I'll have the benchmarks take in a "batch size factor" argument to the init function, and have this included in the filename. By default the benchmarks will run with 1x and 2x factors, but these factors will be configurable as a list from the command line.
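That proposal could look roughly like the sketch below. The `--batch-size-factors` flag, its comma-separated format, and the `run_all` helper are assumptions for illustration, not the PR's actual interface:

```python
import argparse

# Hypothetical sketch of the proposed CLI behavior: run every registered
# benchmark at a list of batch size factors, defaulting to 1.0 and 2.0.


def parse_factors(argv: list[str]) -> list[float]:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--batch-size-factors",
        type=lambda s: [float(f) for f in s.split(",")],
        default=[1.0, 2.0],
    )
    return parser.parse_args(argv).batch_size_factors


def run_all(benchmarks: dict[str, object], factors: list[float]) -> list[tuple[str, float]]:
    # In the real framework each pair would become a
    # cls.run_benchmark(batch_size_factor=f) call; here we just enumerate.
    return [(name, f) for name in benchmarks for f in factors]


assert parse_factors([]) == [1.0, 2.0]
assert parse_factors(["--batch-size-factors", "0.5,1,4"]) == [0.5, 1.0, 4.0]
```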
Ah I think you meant "maxing out". You could use this batch size factor feature, or another option is to use benchmarks that explicitly are tuned not to maximize gpu utilization. For the batch size factor, perhaps we'd retune the base case to be a low utilization case, and then use two factors that will each be in the maximized regime.
I would suggest not using these isolated benchmarks as a way to measure "realistic" performance, in the sense of what you'll get when running FCN3 on top-end hardware. I don't think they can accomplish this well. The purpose of these isolated benchmarks is to, well, isolate the code. At sufficiently small problem sizes, overhead dominates and all versions of the code run in the same amount of time, in a way that also doesn't reflect realistic use.
When I say I'm maxing out GPU occupancy, I just mean that I'm increasing the problem size until (diff_run_time/diff_problem_size) asymptotes, though I'm doing this manually and not very well - you could likely do better. This by far does not mean maximum memory utilization. On my T4 I'm kind of targeting 30ms+ run-times. I find when the run times are 1-7ms, the execution time is quite insensitive to significant changes in memory ordering.
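That manual procedure can be made concrete: compute the marginal cost diff_run_time/diff_problem_size between successive sizes and look for where it stops changing. A sketch with made-up timings (not measurements); in the overhead-dominated regime the marginal cost is small, and it levels off once runtime grows linearly with problem size:

```python
# Sketch of the procedure described above: sweep the problem size and watch
# the marginal cost d(run_time)/d(problem_size) until it asymptotes, i.e.
# overhead no longer dominates. Timings below are invented for illustration.
def marginal_costs(sizes: list[int], times_ms: list[float]) -> list[float]:
    return [
        (times_ms[i + 1] - times_ms[i]) / (sizes[i + 1] - sizes[i])
        for i in range(len(sizes) - 1)
    ]


sizes = [1, 2, 4, 8, 16]  # e.g. batch sizes
times_ms = [2.0, 2.2, 3.0, 5.0, 9.0]  # flat at first (overhead), then linear
costs = marginal_costs(sizes, times_ms)
# Once successive marginal costs agree to within a few percent, the benchmark
# is in the compute-bound regime worth measuring.
assert costs[-1] == (9.0 - 5.0) / (16 - 8)
```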
You can write a benchmark that uses child timers to fully map out timings within, for example, FCN3, and that code can tell you how much time is spent in each section in that realistic case. We've done this in our SFNO, where the block takes in a timer argument. I'm considering refactoring this to use our GlobalTimer singleton class instead so it doesn't show up in method signatures. Some leakage does occur when memory ordering changes, e.g. you may get one step faster by making a change that delays a contiguous call into a later block. But this is better for also including the impact different kernels have on one another during execution (e.g. these changes in memory ordering do affect the execution time).