[fix] rewrite do_bench with fixed warmup/rep runs mode #411

Open · wants to merge 4 commits into base: master
99 changes: 97 additions & 2 deletions benchmark/performance_utils.py
@@ -6,7 +6,6 @@

import pytest
import torch
import triton
import yaml

import flag_gems
@@ -55,6 +54,102 @@ def SkipVersion(module_name, skip_pattern):
    return (major, minor) > (M, N)


def triton_do_bench_rewritten(
    fn,
    warmup=25,
    rep=100,
    grad_to_none=None,
    quantiles=None,
    fast_flush=True,
    return_mode="mean",
    device_type="cuda",
    fixed_warmup_rep_runs=True,
):
"""
This is a rewritten version of the original `triton.testing.do_bench` function.

Benchmark the runtime of the provided function. By default, return the median runtime
of :code:`fn` along with the 20-th and 80-th performance percentile.

This function supports two modes for determining the number of warmup and repetition
runs, by appending a parameter called `fixed_warmup_rep_runs`:
1. Dynamic Mode (the original implementation of `triton.testing.do_bench`):
Estimates the runtime of the kernel and dynamically adjusts the number of warmup and
repetition runs based on the provided `warmup` and `rep` times (in milliseconds).
2. Fixed Mode (default in this rewritten version, and consistent with torch's testing):
Uses the provided `warmup` and `rep` values directly as the number of warmup and
repetition runs.

Please refer to the original implementation of `triton.testing.do_bench` function for
more details:
https://github.com/triton-lang/triton/blob/199fd8a239068318e94d39843c4c676f44883bd3/python/triton/testing.py#L162
"""

    assert return_mode in ["min", "max", "mean", "median"]

    di = torch._dynamo.device_interface.get_interface_for_device(device_type)

    fn()
    di.synchronize()

    # We maintain a buffer of 256 MB that we clear
    # before each kernel call to make sure that the L2
    # doesn't contain any input data before the run
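    # Both branches below cover the same ~256 MB: 64M 4-byte int32 elements with
    # fast_flush, or 256M 1-byte int8 elements otherwise.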
    if fast_flush:
        cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=device_type)
    else:
        cache = torch.empty(int(256e6), dtype=torch.int8, device=device_type)

    if not fixed_warmup_rep_runs:
        # Estimate the runtime of the function
        start_event = di.Event(enable_timing=True)
        end_event = di.Event(enable_timing=True)
        start_event.record()
        for _ in range(5):
            cache.zero_()
            fn()
        end_event.record()
        di.synchronize()
        estimate_ms = start_event.elapsed_time(end_event) / 5
        # compute number of warmup and repeat
        n_warmup = max(1, int(warmup / estimate_ms))
        n_repeat = max(1, int(rep / estimate_ms))
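        # e.g. with warmup=25 (ms) and an estimated 0.5 ms per call, this yields
        # n_warmup = 50; with rep=100 (ms), n_repeat = 200.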
    else:
        n_warmup = warmup
        n_repeat = rep

    start_event = [di.Event(enable_timing=True) for i in range(n_repeat)]
    end_event = [di.Event(enable_timing=True) for i in range(n_repeat)]
    # Warm-up
    for _ in range(n_warmup):
        fn()
    # Benchmark
    for i in range(n_repeat):
        # we don't want `fn` to accumulate gradient values
        # if it contains a backward pass. So we clear the
        # provided gradients
        if grad_to_none is not None:
            for x in grad_to_none:
                x.grad = None
        # we clear the L2 cache before each run
        cache.zero_()
        # record time of `fn`
        start_event[i].record()
        fn()
        end_event[i].record()
    # Record clocks
    di.synchronize()
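    # Event.elapsed_time reports milliseconds, so each entry of `times` is the
    # latency of one benchmarked call in ms.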
    times = torch.tensor(
        [s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float
    )
    if quantiles is not None:
        ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
        if len(ret) == 1:
            ret = ret[0]
        return ret
    return getattr(torch, return_mode)(times).item()


class Benchmark:
    device: str = device
    DEFAULT_METRICS = DEFAULT_METRICS
@@ -247,7 +342,7 @@ def get_latency(self, op, *args, **kwargs):
            end = time.time()
            latency = (end - start) / Config.repetition * 1000
        else:
            latency = triton.testing.do_bench(
            latency = triton_do_bench_rewritten(
                fn,
                warmup=Config.warm_up,
                rep=Config.repetition,
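As a quick sanity check of the new interface, here is a minimal usage sketch (not part of the diff); the dummy matmul workload and the import path are illustrative assumptions:

import torch

# assumed import path; adjust to however benchmark/performance_utils.py is exposed
from benchmark.performance_utils import triton_do_bench_rewritten

x = torch.randn(4096, 4096, device="cuda")
fn = lambda: torch.mm(x, x)

# Fixed mode (the default here): run exactly 25 warmup and 100 timed iterations.
mean_ms = triton_do_bench_rewritten(fn, warmup=25, rep=100, return_mode="mean")

# Dynamic mode (original triton behaviour): warmup/rep are millisecond budgets.
median_ms = triton_do_bench_rewritten(
    fn, warmup=25, rep=100, return_mode="median", fixed_warmup_rep_runs=False
)

# With quantiles, the return value is a list of latencies (in ms), one per quantile.
p20, p50, p80 = triton_do_bench_rewritten(fn, quantiles=[0.2, 0.5, 0.8])

In fixed mode the call above matches how get_latency now passes Config.warm_up and Config.repetition: they are interpreted as iteration counts rather than time budgets.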