
Commit 803677e

Merge branch 'master' into master
2 parents: 7d9f603 + 18d4127

374 files changed, +51749 -396 lines changed


.gitattributes

+2
@@ -0,0 +1,2 @@
+src/flag_gems/runtime/backend/_iluvatar/**/* diff=nodiff
+src/flag_gems/runtime/backend/_metax/**/* diff=nodiff

benchmark/attri_util.py

+2
@@ -72,6 +72,8 @@ class BenchmarkMetrics:
     tflops: Optional[float] = None
     # Utilization (not implemented yet)
     utilization: Optional[float] = None
+    # Speedup compared to base data
+    compared_speedup: Optional[float] = None
     # Error message
     error_msg: Optional[str] = None
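A note on how the new compared_speedup field is meant to be filled: it holds the ratio of a baseline run's latency (read from a second log file) to this run's latency, computed later in benchmark/summary_for_plot.py's compare_main. A minimal sketch, with invented numbers:

# Illustrative only: the latencies below are made up; real values come from parsed logs.
from attri_util import BenchmarkMetrics  # assumes the benchmark/ directory is on sys.path

metric = BenchmarkMetrics(latency=0.8)  # latency measured in this run, in ms
baseline_latency = 1.2                  # latency of the same op/dtype/shape in the baseline log
metric.compared_speedup = baseline_latency / metric.latency  # -> 1.5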

benchmark/conftest.py

+9 -1
@@ -21,6 +21,7 @@
 )
 
 device = flag_gems.device
+vendor_name = flag_gems.vendor_name
 
 
 class BenchConfig:
@@ -29,6 +30,11 @@ def __init__(self):
         self.bench_level = BenchLevel.COMPREHENSIVE
         self.warm_up = DEFAULT_WARMUP_COUNT
         self.repetition = DEFAULT_ITER_COUNT
+        if (
+            vendor_name == "kunlunxin"
+        ):  # Speed Up Benchmark Test, Big Shape Will Cause Timeout
+            self.warm_up = 1
+            self.repetition = 1
         self.record_log = False
         self.user_desired_dtypes = None
         self.user_desired_metrics = None
@@ -41,7 +47,9 @@ def __init__(self):
 
 def pytest_addoption(parser):
     parser.addoption(
-        "--mode",
+        "--mode"
+        if vendor_name != "kunlunxin"
+        else "--fg_mode",  # TODO: fix pytest-* common --mode args
         action="store",
         default=device,
         required=False,
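In practice the rename means the benchmark suite's own option is exposed as --fg_mode on kunlunxin and as --mode everywhere else. A hedged sketch of how a fixture could read it back; the fixture name and scope are assumptions, not part of this commit:

import flag_gems
import pytest

vendor_name = flag_gems.vendor_name


@pytest.fixture(scope="session")
def bench_mode(request):
    # pytest's getoption accepts the registered flag name, dashes included.
    flag = "--mode" if vendor_name != "kunlunxin" else "--fg_mode"
    return request.config.getoption(flag)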

benchmark/core_shapes.yaml

+27
@@ -192,3 +192,30 @@ AttentionBenchmark:
     - [4, 8, 2048, 128]
     - [4, 8, 3072, 128]
     - [4, 8, 4096, 128]
+
+KronBenchmark:
+  shapes:
+    - [16,16]
+    - [64,64]
+    - [128,128]
+    - [256,256]
+    - [4, 8, 16, 32]
+    - [4, 8, 32, 32]
+    - [4, 8, 64, 32]
+    - [4, 8, 128, 32]
+
+IndexPutAccFalseBenchmark:
+  shapes:
+    - [[268435456,], [[65536,],], [65536,]]
+    - [[32, 32], [[8,], [2, 8]], [8,]]
+    - [[1024, 1024], [[4, 64],], [1024,]]
+    - [[512, 512, 512], [[2, 128], [128,], [128,]], [128,]]
+    - [[512, 512, 512], [[2, 128],], [512,]]
+
+IndexPutAccTrueBenchmark:
+  shapes:
+    - [[268435456,], [[65536,],], [65536,]]
+    - [[32, 32], [[8,], [8,]], [8,]]
+    - [[1024, 1024], [[64,], [64,]], [64,]]
+    - [[512, 512, 512], [[128,], [128,], [128,]], [128,]]
+    - [[512, 512, 512], [[2, 128], [2, 128], [2, 128]], [2, 128]]
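A hedged reading of the IndexPut entries: each item looks like (input shape, list of index shapes, values shape), mirroring torch.Tensor.index_put_'s arguments, with the AccTrue/AccFalse variants mapping to accumulate=True/False. That interpretation is inferred from the benchmark names, not stated in the YAML. A sketch of materializing one entry under that assumption:

import torch

# One IndexPutAccFalseBenchmark entry: [[32, 32], [[8,], [2, 8]], [8,]]
input_shape, index_shapes, values_shape = [32, 32], [[8], [2, 8]], [8]

inp = torch.randn(input_shape)
indices = [
    torch.randint(0, dim, tuple(idx_shape))
    for dim, idx_shape in zip(input_shape, index_shapes)
]
values = torch.randn(values_shape)
inp.index_put_(indices, values, accumulate=False)  # the AccTrue variant would pass accumulate=True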

benchmark/performance_utils.py

+20 -5
@@ -1,6 +1,7 @@
 import gc
 import importlib
 import logging
+import os
 import time
 from typing import Any, Generator, List, Optional, Tuple
 
@@ -28,7 +29,11 @@
 torch_backend_device = flag_gems.runtime.torch_backend_device
 torch_device_fn = flag_gems.runtime.torch_device_fn
 device = flag_gems.device
-torch_backend_device.matmul.allow_tf32 = False
+vendor_name = flag_gems.vendor_name
+if device == "musa":
+    torch.backends.mudnn.allow_tf32 = False
+else:
+    torch_backend_device.matmul.allow_tf32 = False
 
 
 def SkipVersion(module_name, skip_pattern):
@@ -225,6 +230,11 @@ def init_user_config(self):
         self.cpu_mode = Config.cpu_mode
         self.set_dtypes(Config.user_desired_dtypes)
         self.set_metrics(Config.user_desired_metrics)
+        if vendor_name == "kunlunxin":
+            Config.shape_file = os.path.join(
+                os.path.dirname(__file__),
+                "../src/flag_gems/runtime/backend/_kunlunxin/core_shapes.yaml",
+            )  # Speed Up Benchmark Test, Big Shape Will Cause Timeout
         self.set_shapes(Config.shape_file)
 
     def set_gems(self, gems_op):
@@ -247,7 +257,12 @@ def get_latency(self, op, *args, **kwargs):
             end = time.time()
             latency = (end - start) / Config.repetition * 1000
         else:
-            latency = triton.testing.do_bench(
+            do_bench = (
+                triton.musa_testing.do_bench
+                if device == "musa"
+                else triton.testing.do_bench
+            )
+            latency = do_bench(
                 fn,
                 warmup=Config.warm_up,
                 rep=Config.repetition,
@@ -457,10 +472,10 @@ def generate_tensor_input(shape, dtype, device):
             torch.iinfo(dtype).max,
             shape,
             dtype=dtype,
-            device=device,
-        )
+            device="cpu",
+        ).to(device)
     elif dtype in BOOL_DTYPES:
-        return torch.randint(0, 2, size=shape, dtype=dtype, device=device)
+        return torch.randint(0, 2, size=shape, dtype=dtype, device="cpu").to(device)
 
 
 def binary_input_fn(shape, cur_dtype, device):
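The last hunk switches integer and bool input generation to draw random data on the CPU and then copy it to the target device. The commit does not state the motivation; a plausible reading is that device-side randint is unreliable or unsupported for these dtypes on some vendor backends. A minimal sketch of the resulting pattern, with example shape/dtype/device values:

import torch

def int_input(shape, dtype, device):
    # Draw on the host RNG, then move the tensor to the benchmark device.
    return torch.randint(
        torch.iinfo(dtype).min,
        torch.iinfo(dtype).max,
        shape,
        dtype=dtype,
        device="cpu",
    ).to(device)

x = int_input((1024,), torch.int32, "cpu")  # "cpu" stands in for the real benchmark device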

benchmark/summary_for_plot.py

+164 -20
@@ -32,10 +32,13 @@
 import json
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import List
+from typing import Any, Dict, List
 
 from attri_util import BenchmarkMetrics, BenchmarkResult
 
+# to enable log files crossing speedup calculation
+ENABLE_COMPARE = False
+
 
 @dataclass
 class SummaryResultOverDtype:
@@ -47,20 +50,50 @@ class SummaryResultOverDtype:
     int32_speedup: float = 0.0
     bool_speedup: float = 0.0
     cfloat_speedup: float = 0.0
+
+    # to calculate the speedup across log files.
+    compared_float16_speedup: float = 0.0
+    compared_float32_speedup: float = 0.0
+    compared_bfloat16_speedup: float = 0.0
+    compared_int16_speedup: float = 0.0
+    compared_int32_speedup: float = 0.0
+    compared_bool_speedup: float = 0.0
+    compared_cfloat_speedup: float = 0.0
     all_tests_passed: bool = False
 
     def __str__(self) -> str:
         all_shapes_status = "yes" if self.all_tests_passed else "no"
         return (
-            f"{self.op_name:<30} "
-            f"{self.float16_speedup:<20.6f} "
-            f"{self.float32_speedup:<20.6f} "
-            f"{self.bfloat16_speedup:<20.6f} "
-            f"{self.int16_speedup:<20.6f} "
-            f"{self.int32_speedup:<20.6f} "
-            f"{self.bool_speedup:<20.6f} "
-            f"{self.cfloat_speedup:<20.6f}"
-            f"{all_shapes_status:<20}"
+            (
+                f"{self.op_name:<30} "
+                f"{self.float16_speedup:<20.6f} "
+                f"{self.float32_speedup:<20.6f} "
+                f"{self.bfloat16_speedup:<20.6f} "
+                f"{self.int16_speedup:<20.6f} "
+                f"{self.int32_speedup:<20.6f} "
+                f"{self.bool_speedup:<20.6f} "
+                f"{self.cfloat_speedup:<20.6f}"
+                f"{self.compared_float16_speedup:<20.6f}"
+                f"{self.compared_float32_speedup:<20.6f}"
+                f"{self.compared_bfloat16_speedup:<20.6f}"
+                f"{self.compared_int16_speedup:<20.6f}"
+                f"{self.compared_int32_speedup:<20.6f}"
+                f"{self.compared_bool_speedup:<20.6f}"
+                f"{self.compared_cfloat_speedup:<20.6f}"
+                f"{all_shapes_status:<20}"
+            )
+            if ENABLE_COMPARE
+            else (
+                f"{self.op_name:<30} "
+                f"{self.float16_speedup:<20.6f} "
+                f"{self.float32_speedup:<20.6f} "
+                f"{self.bfloat16_speedup:<20.6f} "
+                f"{self.int16_speedup:<20.6f} "
+                f"{self.int32_speedup:<20.6f} "
+                f"{self.bool_speedup:<20.6f} "
+                f"{self.cfloat_speedup:<20.6f}"
+                f"{all_shapes_status:<20}"
+            )
         )
 
 
@@ -103,6 +136,56 @@ def parse_log(log_file_path: str) -> List[BenchmarkResult]:
     return benchmark_results
 
 
+def get_key_by_op_dtype_shape(op_name, dtype, shape):
+    return hex(hash((hash(op_name), hash(dtype), hash(shape))))
+
+
+def parse_log_to_dict(log_file_path: str) -> Dict[int, Any]:
+    with open(log_file_path, "r") as file:
+        log_lines = [
+            line
+            for line in file.read().strip().split("\n")
+            if line.startswith("[INFO]")
+        ]
+
+    # dict(op_name, dict(dtype, dict(shape, latency))
+    benchmark_results = dict()
+    for line in log_lines:
+        if line.startswith("[INFO]"):
+            json_str = line[len("[INFO] ") :]
+            data = json.loads(json_str)
+            op_name = (data["op_name"],)
+            dtype = (data["dtype"],)
+            mode = (data["mode"],)
+            level = (data["level"],)
+            benchmark_result = BenchmarkResult(
+                op_name,
+                dtype,
+                mode,
+                level,
+                result=[
+                    BenchmarkMetrics(
+                        legacy_shape=metric.get("legacy_shape"),
+                        shape_detail=metric.get("shape_detail", []),
+                        latency_base=metric.get("latency_base"),
+                        latency=metric.get("latency"),
+                        speedup=metric.get("speedup"),
+                        accuracy=metric.get("accuracy"),
+                        tflops=metric.get("tflops"),
+                        utilization=metric.get("utilization"),
+                        error_msg=metric.get("error_msg"),
+                    )
+                    for metric in data["result"]
+                ],
+            )
+            for result in benchmark_result.result:
+                key = get_key_by_op_dtype_shape(
+                    op_name[0], dtype[0], str(result.shape_detail)
+                )
+                benchmark_results[key] = result.latency
+    return benchmark_results
+
+
 def calculate_avg_speedup_over_dtype(metrics):
     speedups = [
         metric.speedup
@@ -112,6 +195,15 @@ def calculate_avg_speedup_over_dtype(metrics):
     return sum(speedups) / len(speedups) if speedups else 0.0
 
 
+def calculate_avg_compared_speedup_over_dtype(metrics):
+    compared_speedups = [
+        metric.compared_speedup
+        for metric in metrics
+        if metric.compared_speedup is not None and metric.error_msg is None
+    ]
+    return sum(compared_speedups) / len(compared_speedups) if compared_speedups else 0.0
+
+
 def all_benchshape_passed(metrics):
     return all(metric.error_msg is None for metric in metrics)
 
@@ -132,6 +224,7 @@ def summary_for_plot(benchmark_results):
     for item in benchmark_results:
         op_name = item.op_name
         avg_speedup = calculate_avg_speedup_over_dtype(item.result)
+        avg_compared_speedup = calculate_avg_compared_speedup_over_dtype(item.result)
         cur_op_summary = summary[op_name]
         cur_op_summary.op_name = op_name
         cur_op_summary.all_tests_passed = all_benchshape_passed(item.result)
@@ -140,20 +233,47 @@ def summary_for_plot(benchmark_results):
             dtype_mapping.get(item.dtype, "float16_speedup"),
             avg_speedup,
         )
+        if ENABLE_COMPARE:
+            setattr(
+                summary[op_name],
+                "compared_" + dtype_mapping.get(item.dtype, "float16_speedup"),
+                avg_compared_speedup,
+            )
 
     # sort the keys based on `op_name`
     sorted_summary = sorted(summary.values(), key=lambda x: x.op_name)
 
     header = (
-        f"{'op_name':<30} "
-        f"{'float16_speedup':<20} "
-        f"{'float32_speedup':<20} "
-        f"{'bfloat16_speedup':<20} "
-        f"{'int16_speedup':<20} "
-        f"{'int32_speedup':<20} "
-        f"{'bool_speedup':<20} "
-        f"{'cfloat_speedup':<20}"
-        f"{'all_tests_passed':<20}"
+        (
+            f"{'op_name':<30} "
+            f"{'float16_speedup':<20} "
+            f"{'float32_speedup':<20} "
+            f"{'bfloat16_speedup':<20} "
+            f"{'int16_speedup':<20} "
+            f"{'int32_speedup':<20} "
+            f"{'bool_speedup':<20} "
+            f"{'cfloat_speedup':<20}"
+            f"{'comp_fp16_speedup':<20}"
+            f"{'comp_fp32_speedup':<20}"
+            f"{'comp_bf16_speedup':<20}"
+            f"{'comp_int16_speedup':<20}"
+            f"{'comp_int32_speedup':<20}"
+            f"{'comp_bool_speedup':<20}"
+            f"{'comp_cfloat_speedup':<20}"
+            f"{'all_tests_passed':<20}"
+        )
+        if ENABLE_COMPARE
+        else (
+            f"{'op_name':<30} "
+            f"{'float16_speedup':<20} "
+            f"{'float32_speedup':<20} "
+            f"{'bfloat16_speedup':<20} "
+            f"{'int16_speedup':<20} "
+            f"{'int32_speedup':<20} "
+            f"{'bool_speedup':<20} "
+            f"{'cfloat_speedup':<20}"
+            f"{'all_tests_passed':<20}"
+        )
     )
 
     print(header)
@@ -163,6 +283,19 @@ def summary_for_plot(benchmark_results):
     return summary
 
 
+def compare_main(log_file_a, log_file_b):
+    result_a = parse_log(log_file_a)
+    result_b = parse_log_to_dict(log_file_b)
+    for result in result_a:
+        for sub_result in result.result:
+            key = get_key_by_op_dtype_shape(
+                result.op_name, result.dtype, str(sub_result.shape_detail)
+            )
+            sub_result.compared_speedup = result_b.get(key, 0) / sub_result.latency
+
+    summary_for_plot(result_a)
+
+
 def main(log_file_path):
     result = parse_log(log_file_path)
     summary_for_plot(result)
@@ -171,6 +304,17 @@ def main(log_file_path):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Parse benchmark log file.")
     parser.add_argument("log_file_path", type=str, help="Path to the log file.")
+    parser.add_argument(
+        "--compare",
+        "-c",
+        type=str,
+        default="",
+        help="Path to a log file with baseline data to get speedup statistics across 2 log files",
+    )
     args = parser.parse_args()
 
-    main(args.log_file_path)
+    if not args.compare == "":
+        ENABLE_COMPARE = True
+        compare_main(args.log_file_path, args.compare)
+    else:
+        main(args.log_file_path)
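Taken together, the summary script now supports a baseline comparison mode, i.e. python summary_for_plot.py current.log --compare baseline.log: it parses both logs, matches rows by (op_name, dtype, shape), fills compared_speedup as baseline latency divided by current latency, and prints the extra comp_* columns. A hedged sketch of driving the same path from Python; the log file names are placeholders:

import summary_for_plot as sfp  # assumes the benchmark/ directory is importable

sfp.ENABLE_COMPARE = True  # module-level switch that enables the comp_* columns
sfp.compare_main("current.log", "baseline.log")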
