From 854eecd1480be4a24412b9d13ab10d5fc54a5727 Mon Sep 17 00:00:00 2001
From: Anatoly Myachev <anatoly.myachev@intel.com>
Date: Wed, 15 Jan 2025 11:32:24 +0000
Subject: [PATCH 1/3] Test Mark changes

Signed-off-by: Anatoly Myachev <anatoly.myachev@intel.com>
---
 python/triton/testing.py | 31 ++++++++++++++++++-------------
 1 file changed, 18 insertions(+), 13 deletions(-)

diff --git a/python/triton/testing.py b/python/triton/testing.py
index a2690cde62..1df746fc15 100644
--- a/python/triton/testing.py
+++ b/python/triton/testing.py
@@ -377,12 +377,14 @@ def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: b
         y_min = [f'{x}-min' for x in bench.line_names]
         y_max = [f'{x}-max' for x in bench.line_names]
         x_names = list(bench.x_names)
-        df = pd.DataFrame(columns=x_names + y_mean + y_min + y_max)
-        for x in bench.x_vals:
+        if len(x_names) == 1:
+            index = pd.Index(bench.x_vals, name=x_names[0])
+        else:
             # x can be a single value or a sequence of values.
-            if not isinstance(x, (list, tuple)):
-                x = [x for _ in x_names]
-
+            x_vals = [tuple(x for _ in x_names) for x in bench.x_vals if not isinstance(x, (list, tuple))]
+            index = pd.MultiIndex.from_tuples(x_vals, names=x_names)
+        df = pd.DataFrame(index=index, columns=y_mean + y_min + y_max)
+        for x in df.index:
             if len(x) != len(x_names):
                 raise ValueError(f"Expected {len(x_names)} values, got {x}")
             x_args = dict(zip(x_names, x))
@@ -397,24 +399,28 @@ def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: b
                 row_mean += [y_mean]
                 row_min += [y_min]
                 row_max += [y_max]
-            df.loc[len(df)] = list(x) + row_mean + row_min + row_max
+            df.loc[x] = row_mean + row_min + row_max
 
         if bench.plot_name:
             plt.figure()
             ax = plt.subplot()
             # Plot first x value on x axis if there are multiple.
-            first_x = x_names[0]
+            index_name = x_names[0]
+            if len(x_names) == 1:
+                index = df.index
+            else:
+                index = df.index.get_level_values(0)
             for i, y in enumerate(bench.line_names):
                 y_min, y_max = df[y + '-min'], df[y + '-max']
                 col = bench.styles[i][0] if bench.styles else None
                 sty = bench.styles[i][1] if bench.styles else None
-                ax.plot(df[first_x], df[y], label=y, color=col, ls=sty)
+                ax.plot(index, df[y], label=y, color=col, ls=sty)
                 if not y_min.isnull().all() and not y_max.isnull().all():
                     y_min = y_min.astype(float)
                     y_max = y_max.astype(float)
-                    ax.fill_between(df[first_x], y_min, y_max, alpha=0.15, color=col)
+                    ax.fill_between(index, y_min, y_max, alpha=0.15, color=col)
             ax.legend()
-            ax.set_xlabel(bench.xlabel or first_x)
+            ax.set_xlabel(bench.xlabel or index_name)
             ax.set_ylabel(bench.ylabel)
             # ax.set_title(bench.plot_name)
             ax.set_xscale("log" if bench.x_log else "linear")
@@ -423,7 +429,7 @@ def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: b
                 plt.show()
             if save_path:
                 plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
-        df = df[x_names + bench.line_names]
+        df = df[bench.line_names]
         if diff_col and df.shape[1] == 2:
             col0, col1 = df.columns.tolist()
             df['Diff'] = df[col1] - df[col0]
@@ -432,8 +438,7 @@ def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: b
             print(bench.plot_name + ':')
             print(df.to_string())
         if save_path:
-            df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format=f"%.{save_precision}f",
-                      index=False)
+            df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format=f"%.{save_precision}f")
         return df
 
     def run(self, show_plots=False, print_data=False, save_path='', return_df=False, **kwargs):

From a890e3a9cbb3170ec0bc37eb865627a3ae3618c6 Mon Sep 17 00:00:00 2001
From: Anatoly Myachev <anatoly.myachev@intel.com>
Date: Wed, 15 Jan 2025 12:42:01 +0000
Subject: [PATCH 2/3] fix

Signed-off-by: Anatoly Myachev <anatoly.myachev@intel.com>
---
 python/triton/testing.py | 15 ++++-----------
 1 file changed, 4 insertions(+), 11 deletions(-)

diff --git a/python/triton/testing.py b/python/triton/testing.py
index 1df746fc15..d9fa078438 100644
--- a/python/triton/testing.py
+++ b/python/triton/testing.py
@@ -377,13 +377,9 @@ def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: b
         y_min = [f'{x}-min' for x in bench.line_names]
         y_max = [f'{x}-max' for x in bench.line_names]
         x_names = list(bench.x_names)
-        if len(x_names) == 1:
-            index = pd.Index(bench.x_vals, name=x_names[0])
-        else:
-            # x can be a single value or a sequence of values.
-            x_vals = [tuple(x for _ in x_names) for x in bench.x_vals if not isinstance(x, (list, tuple))]
-            index = pd.MultiIndex.from_tuples(x_vals, names=x_names)
-        df = pd.DataFrame(index=index, columns=y_mean + y_min + y_max)
+        x_vals = [tuple(x for _ in x_names) if not isinstance(x, (list, tuple)) else x for x in bench.x_vals]
+        index = pd.Index(x_vals, name=tuple(x_names))
+        df = pd.DataFrame(index=index, columns=y_mean + y_min + y_max, dtype="float")
         for x in df.index:
             if len(x) != len(x_names):
                 raise ValueError(f"Expected {len(x_names)} values, got {x}")
@@ -406,10 +402,7 @@ def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: b
             ax = plt.subplot()
             # Plot first x value on x axis if there are multiple.
             index_name = x_names[0]
-            if len(x_names) == 1:
-                index = df.index
-            else:
-                index = df.index.get_level_values(0)
+            index = df.index.get_level_values(0)
             for i, y in enumerate(bench.line_names):
                 y_min, y_max = df[y + '-min'], df[y + '-max']
                 col = bench.styles[i][0] if bench.styles else None

From e1834c788f82d5d38bf7b6d9a1e4fca8eff2ce6b Mon Sep 17 00:00:00 2001
From: Anatoly Myachev <anatoly.myachev@intel.com>
Date: Wed, 15 Jan 2025 18:50:44 +0000
Subject: [PATCH 3/3] multiindex columns

Signed-off-by: Anatoly Myachev <anatoly.myachev@intel.com>
---
 python/triton/testing.py | 29 +++++++++++++----------------
 1 file changed, 13 insertions(+), 16 deletions(-)

diff --git a/python/triton/testing.py b/python/triton/testing.py
index d9fa078438..da3bc28192 100644
--- a/python/triton/testing.py
+++ b/python/triton/testing.py
@@ -369,33 +369,30 @@ def __init__(self, fn, benchmarks):
 
     def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, diff_col=False,
              save_precision=6, **kwrags):
-        import os
-
         import matplotlib.pyplot as plt
         import pandas as pd
-        y_mean = bench.line_names
-        y_min = [f'{x}-min' for x in bench.line_names]
-        y_max = [f'{x}-max' for x in bench.line_names]
+
         x_names = list(bench.x_names)
         x_vals = [tuple(x for _ in x_names) if not isinstance(x, (list, tuple)) else x for x in bench.x_vals]
         index = pd.Index(x_vals, name=tuple(x_names))
-        df = pd.DataFrame(index=index, columns=y_mean + y_min + y_max, dtype="float")
+
+        columns = pd.MultiIndex.from_product(([bench.ylabel], bench.line_names, ["mean", "min", "max"]),
+                                             names=("unit", "provider", "method"))
+        df = pd.DataFrame(index=index, columns=columns, dtype="float")
         for x in df.index:
             if len(x) != len(x_names):
                 raise ValueError(f"Expected {len(x_names)} values, got {x}")
             x_args = dict(zip(x_names, x))
 
-            row_mean, row_min, row_max = [], [], []
-            for y in bench.line_vals:
+            for idx, y in enumerate(bench.line_vals):
                 ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwrags)
                 try:
                     y_mean, y_min, y_max = ret
                 except TypeError:
                     y_mean, y_min, y_max = ret, None, None
-                row_mean += [y_mean]
-                row_min += [y_min]
-                row_max += [y_max]
-            df.loc[x] = row_mean + row_min + row_max
+                df.at[x, (bench.ylabel, bench.line_names[idx], "mean")] = y_mean
+                df.at[x, (bench.ylabel, bench.line_names[idx], "min")] = y_min
+                df.at[x, (bench.ylabel, bench.line_names[idx], "max")] = y_max
 
         if bench.plot_name:
             plt.figure()
@@ -404,10 +401,10 @@ def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: b
             index_name = x_names[0]
             index = df.index.get_level_values(0)
             for i, y in enumerate(bench.line_names):
-                y_min, y_max = df[y + '-min'], df[y + '-max']
+                y_min, y_max = df[(bench.ylabel, y, "min")], df[(bench.ylabel, y, "max")]
                 col = bench.styles[i][0] if bench.styles else None
                 sty = bench.styles[i][1] if bench.styles else None
-                ax.plot(index, df[y], label=y, color=col, ls=sty)
+                ax.plot(index, df[(bench.ylabel, y, "mean")], label=y, color=col, ls=sty)
                 if not y_min.isnull().all() and not y_max.isnull().all():
                     y_min = y_min.astype(float)
                     y_max = y_max.astype(float)
@@ -422,10 +419,10 @@ def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: b
                 plt.show()
             if save_path:
                 plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
-        df = df[bench.line_names]
+        df = df.loc[:, (slice(None), slice(None), "mean")]
         if diff_col and df.shape[1] == 2:
             col0, col1 = df.columns.tolist()
-            df['Diff'] = df[col1] - df[col0]
+            df[(bench.ylabel, '', 'Diff')] = df[col1] - df[col0]
 
         if print_data:
             print(bench.plot_name + ':')