diff --git a/benchmark/performance_utils.py b/benchmark/performance_utils.py
index b30243f23..6b7c7c9da 100644
--- a/benchmark/performance_utils.py
+++ b/benchmark/performance_utils.py
@@ -6,7 +6,6 @@
 import pytest
 import torch
-import triton
 import yaml
 
 import flag_gems
@@ -55,6 +54,102 @@ def SkipVersion(module_name, skip_pattern):
     return (major, minor) > (M, N)
 
 
+def triton_do_bench_rewritten(
+    fn,
+    warmup=25,
+    rep=100,
+    grad_to_none=None,
+    quantiles=None,
+    fast_flush=True,
+    return_mode="mean",
+    device_type="cuda",
+    fixed_warmup_rep_runs=True,
+):
+    """
+    This is a rewritten version of the original `triton.testing.do_bench` function.
+
+    Benchmark the runtime of the provided function. By default, return the mean runtime
+    of :code:`fn`; when :code:`quantiles` is given, return the requested runtime
+    quantiles instead.
+
+    This function supports two modes for determining the number of warmup and repetition
+    runs, selected via the `fixed_warmup_rep_runs` parameter:
+    1. Dynamic Mode (the original implementation of `triton.testing.do_bench`):
+       Estimates the runtime of the kernel and dynamically adjusts the number of warmup
+       and repetition runs based on the provided `warmup` and `rep` times
+       (in milliseconds).
+    2. Fixed Mode (default in this rewritten version, and consistent with torch's
+       testing): Uses the provided `warmup` and `rep` values directly as the number of
+       warmup and repetition runs.
+
+    Please refer to the original implementation of the `triton.testing.do_bench`
+    function for more details:
+    https://github.com/triton-lang/triton/blob/199fd8a239068318e94d39843c4c676f44883bd3/python/triton/testing.py#L162
+    """
+
+    assert return_mode in ["min", "max", "mean", "median"]
+
+    di = torch._dynamo.device_interface.get_interface_for_device(device_type)
+
+    fn()
+    di.synchronize()
+
+    # We maintain a buffer of 256 MB that we clear
+    # before each kernel call to make sure that the L2 cache
+    # doesn't contain any input data before the run
+    if fast_flush:
+        cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=device_type)
+    else:
+        cache = torch.empty(int(256e6), dtype=torch.int8, device=device_type)
+
+    if not fixed_warmup_rep_runs:
+        # Estimate the runtime of the function
+        start_event = di.Event(enable_timing=True)
+        end_event = di.Event(enable_timing=True)
+        start_event.record()
+        for _ in range(5):
+            cache.zero_()
+            fn()
+        end_event.record()
+        di.synchronize()
+        estimate_ms = start_event.elapsed_time(end_event) / 5
+        # compute the number of warmup and repeat runs
+        n_warmup = max(1, int(warmup / estimate_ms))
+        n_repeat = max(1, int(rep / estimate_ms))
+    else:
+        n_warmup = warmup
+        n_repeat = rep
+
+    start_event = [di.Event(enable_timing=True) for i in range(n_repeat)]
+    end_event = [di.Event(enable_timing=True) for i in range(n_repeat)]
+    # Warm-up
+    for _ in range(n_warmup):
+        fn()
+    # Benchmark
+    for i in range(n_repeat):
+        # we don't want `fn` to accumulate gradient values
+        # if it contains a backward pass, so we clear the
+        # provided gradients
+        if grad_to_none is not None:
+            for x in grad_to_none:
+                x.grad = None
+        # we clear the L2 cache before each run
+        cache.zero_()
+        # record time of `fn`
+        start_event[i].record()
+        fn()
+        end_event[i].record()
+    # Record clocks
+    di.synchronize()
+    times = torch.tensor(
+        [s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float
+    )
+    if quantiles is not None:
+        ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
+        if len(ret) == 1:
+            ret = ret[0]
+        return ret
+    return getattr(torch, return_mode)(times).item()
+
+
 class Benchmark:
     device: str = device
     DEFAULT_METRICS = DEFAULT_METRICS
@@ -247,7 +342,7 @@ def get_latency(self, op, *args, **kwargs):
             end = time.time()
             latency = (end - start) / Config.repetition * 1000
         else:
-            latency = triton.testing.do_bench(
+            latency = triton_do_bench_rewritten(
                 fn,
                 warmup=Config.warm_up,
                 rep=Config.repetition,
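# --- Usage sketch (illustrative, not part of the patch above) ----------------
# A minimal example of calling the rewritten helper in both modes. The import
# path `benchmark.performance_utils` and the matmul workload are assumptions
# made for illustration; a CUDA device is required.
import torch

from benchmark.performance_utils import triton_do_bench_rewritten

a = torch.randn(4096, 4096, device="cuda")
b = torch.randn(4096, 4096, device="cuda")

# Fixed mode (the default here): exactly `warmup` warm-up runs and `rep` timed
# runs, matching torch's testing convention.
mean_ms = triton_do_bench_rewritten(lambda: torch.mm(a, b), warmup=25, rep=100)

# Dynamic mode: `warmup` and `rep` are time budgets in milliseconds, and the
# run counts are derived from an estimated kernel time, as in
# triton.testing.do_bench.
median_ms = triton_do_bench_rewritten(
    lambda: torch.mm(a, b),
    warmup=25,
    rep=100,
    return_mode="median",
    fixed_warmup_rep_runs=False,
)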