@@ -32,10 +32,13 @@
 import json
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import List
+from typing import Any, Dict, List

 from attri_util import BenchmarkMetrics, BenchmarkResult

+# Enables speedup calculation across two log files (set via --compare).
+ENABLE_COMPARE = False
+

 @dataclass
 class SummaryResultOverDtype:
@@ -47,20 +50,50 @@ class SummaryResultOverDtype:
     int32_speedup: float = 0.0
     bool_speedup: float = 0.0
     cfloat_speedup: float = 0.0
+
+    # Speedups measured against a baseline log file (see compare_main).
+    compared_float16_speedup: float = 0.0
+    compared_float32_speedup: float = 0.0
+    compared_bfloat16_speedup: float = 0.0
+    compared_int16_speedup: float = 0.0
+    compared_int32_speedup: float = 0.0
+    compared_bool_speedup: float = 0.0
+    compared_cfloat_speedup: float = 0.0
     all_tests_passed: bool = False

     def __str__(self) -> str:
         all_shapes_status = "yes" if self.all_tests_passed else "no"
         return (
-            f"{self.op_name:<30}"
-            f"{self.float16_speedup:<20.6f}"
-            f"{self.float32_speedup:<20.6f}"
-            f"{self.bfloat16_speedup:<20.6f}"
-            f"{self.int16_speedup:<20.6f}"
-            f"{self.int32_speedup:<20.6f}"
-            f"{self.bool_speedup:<20.6f}"
-            f"{self.cfloat_speedup:<20.6f}"
-            f"{all_shapes_status:<20}"
+            (
+                f"{self.op_name:<30}"
+                f"{self.float16_speedup:<20.6f}"
+                f"{self.float32_speedup:<20.6f}"
+                f"{self.bfloat16_speedup:<20.6f}"
+                f"{self.int16_speedup:<20.6f}"
+                f"{self.int32_speedup:<20.6f}"
+                f"{self.bool_speedup:<20.6f}"
+                f"{self.cfloat_speedup:<20.6f}"
+                f"{self.compared_float16_speedup:<20.6f}"
+                f"{self.compared_float32_speedup:<20.6f}"
+                f"{self.compared_bfloat16_speedup:<20.6f}"
+                f"{self.compared_int16_speedup:<20.6f}"
+                f"{self.compared_int32_speedup:<20.6f}"
+                f"{self.compared_bool_speedup:<20.6f}"
+                f"{self.compared_cfloat_speedup:<20.6f}"
+                f"{all_shapes_status:<20}"
+            )
+            if ENABLE_COMPARE
+            else (
+                f"{self.op_name:<30}"
+                f"{self.float16_speedup:<20.6f}"
+                f"{self.float32_speedup:<20.6f}"
+                f"{self.bfloat16_speedup:<20.6f}"
+                f"{self.int16_speedup:<20.6f}"
+                f"{self.int32_speedup:<20.6f}"
+                f"{self.bool_speedup:<20.6f}"
+                f"{self.cfloat_speedup:<20.6f}"
+                f"{all_shapes_status:<20}"
+            )
         )

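For reference, both branches above emit one fixed-width row per op. A minimal standalone sketch of the format specs (hypothetical values, not part of the change):

    # ':<30' and ':<20.6f' left-align each field in a fixed-width column:
    # 30 characters for the op name, 20 per speedup with six decimals.
    op, fp16, fp32 = "abs", 1.234567, 0.987654
    print(f"{op:<30}{fp16:<20.6f}{fp32:<20.6f}")
    # -> 'abs' padded to 30 chars, then '1.234567' and '0.987654' each padded to 20
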
@@ -103,6 +136,56 @@ def parse_log(log_file_path: str) -> List[BenchmarkResult]:
     return benchmark_results


+def get_key_by_op_dtype_shape(op_name, dtype, shape):
+    return hex(hash((hash(op_name), hash(dtype), hash(shape))))
+
+
+def parse_log_to_dict(log_file_path: str) -> Dict[str, Any]:
+    with open(log_file_path, "r") as file:
+        log_lines = [
+            line
+            for line in file.read().strip().split("\n")
+            if line.startswith("[INFO]")
+        ]
+
+    # Flat dict: key derived from (op_name, dtype, shape) -> latency.
+    benchmark_results = dict()
+    for line in log_lines:
+        if line.startswith("[INFO]"):
+            json_str = line[len("[INFO] ") :]
+            data = json.loads(json_str)
+            op_name = data["op_name"]
+            dtype = data["dtype"]
+            mode = data["mode"]
+            level = data["level"]
+            benchmark_result = BenchmarkResult(
+                op_name,
+                dtype,
+                mode,
+                level,
+                result=[
+                    BenchmarkMetrics(
+                        legacy_shape=metric.get("legacy_shape"),
+                        shape_detail=metric.get("shape_detail", []),
+                        latency_base=metric.get("latency_base"),
+                        latency=metric.get("latency"),
+                        speedup=metric.get("speedup"),
+                        accuracy=metric.get("accuracy"),
+                        tflops=metric.get("tflops"),
+                        utilization=metric.get("utilization"),
+                        error_msg=metric.get("error_msg"),
+                    )
+                    for metric in data["result"]
+                ],
+            )
+            for result in benchmark_result.result:
+                key = get_key_by_op_dtype_shape(
+                    op_name, dtype, str(result.shape_detail)
+                )
+                benchmark_results[key] = result.latency
+    return benchmark_results
+
+
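A quick illustration of the lookup key built by `get_key_by_op_dtype_shape` (hypothetical values): the key is a hex string of nested hashes. Note that Python salts `str` hashes per process, so keys only match within a single run; that holds here because both log files are parsed in the same invocation.

    key = hex(hash((hash("add"), hash("float16"), hash("[1024, 1024]"))))
    print(key)  # e.g. '0x6b1d...'; sign and value differ across runs unless PYTHONHASHSEED is fixed
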
 def calculate_avg_speedup_over_dtype(metrics):
     speedups = [
         metric.speedup
@@ -112,6 +195,15 @@ def calculate_avg_speedup_over_dtype(metrics):
     return sum(speedups) / len(speedups) if speedups else 0.0


+def calculate_avg_compared_speedup_over_dtype(metrics):
+    compared_speedups = [
+        metric.compared_speedup
+        for metric in metrics
+        if getattr(metric, "compared_speedup", None) is not None
+        and metric.error_msg is None
+    ]
+    return sum(compared_speedups) / len(compared_speedups) if compared_speedups else 0.0
+
+
 def all_benchshape_passed(metrics):
     return all(metric.error_msg is None for metric in metrics)
@@ -132,6 +224,7 @@ def summary_for_plot(benchmark_results):
     for item in benchmark_results:
         op_name = item.op_name
         avg_speedup = calculate_avg_speedup_over_dtype(item.result)
+        avg_compared_speedup = calculate_avg_compared_speedup_over_dtype(item.result)
         cur_op_summary = summary[op_name]
         cur_op_summary.op_name = op_name
         cur_op_summary.all_tests_passed = all_benchshape_passed(item.result)
@@ -140,20 +233,47 @@ def summary_for_plot(benchmark_results):
             dtype_mapping.get(item.dtype, "float16_speedup"),
             avg_speedup,
         )
+        if ENABLE_COMPARE:
+            setattr(
+                summary[op_name],
+                "compared_" + dtype_mapping.get(item.dtype, "float16_speedup"),
+                avg_compared_speedup,
+            )

     # sort the keys based on `op_name`
     sorted_summary = sorted(summary.values(), key=lambda x: x.op_name)

     header = (
-        f"{'op_name':<30}"
-        f"{'float16_speedup':<20}"
-        f"{'float32_speedup':<20}"
-        f"{'bfloat16_speedup':<20}"
-        f"{'int16_speedup':<20}"
-        f"{'int32_speedup':<20}"
-        f"{'bool_speedup':<20}"
-        f"{'cfloat_speedup':<20}"
-        f"{'all_tests_passed':<20}"
+        (
+            f"{'op_name':<30}"
+            f"{'float16_speedup':<20}"
+            f"{'float32_speedup':<20}"
+            f"{'bfloat16_speedup':<20}"
+            f"{'int16_speedup':<20}"
+            f"{'int32_speedup':<20}"
+            f"{'bool_speedup':<20}"
+            f"{'cfloat_speedup':<20}"
+            f"{'comp_fp16_speedup':<20}"
+            f"{'comp_fp32_speedup':<20}"
+            f"{'comp_bf16_speedup':<20}"
+            f"{'comp_int16_speedup':<20}"
+            f"{'comp_int32_speedup':<20}"
+            f"{'comp_bool_speedup':<20}"
+            f"{'comp_cfloat_speedup':<20}"
+            f"{'all_tests_passed':<20}"
+        )
+        if ENABLE_COMPARE
+        else (
+            f"{'op_name':<30}"
+            f"{'float16_speedup':<20}"
+            f"{'float32_speedup':<20}"
+            f"{'bfloat16_speedup':<20}"
+            f"{'int16_speedup':<20}"
+            f"{'int32_speedup':<20}"
+            f"{'bool_speedup':<20}"
+            f"{'cfloat_speedup':<20}"
+            f"{'all_tests_passed':<20}"
+        )
     )

     print(header)
@@ -163,6 +283,19 @@ def summary_for_plot(benchmark_results):
     return summary


+def compare_main(log_file_a, log_file_b):
+    result_a = parse_log(log_file_a)
+    result_b = parse_log_to_dict(log_file_b)
+    for result in result_a:
+        for sub_result in result.result:
+            key = get_key_by_op_dtype_shape(
+                result.op_name, result.dtype, str(sub_result.shape_detail)
+            )
+            # Skip entries with a missing/zero latency (e.g. failed shapes).
+            if sub_result.latency:
+                sub_result.compared_speedup = result_b.get(key, 0) / sub_result.latency
+
+    summary_for_plot(result_a)
+
+
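A worked example of the convention applied above (hypothetical numbers): `compared_speedup` is the baseline latency divided by the current latency, so values above 1.0 mean this run beat the baseline, and shapes absent from the baseline log fall back to 0.

    baseline_latency = 2.0  # from the --compare (baseline) log
    current_latency = 0.5   # from the primary log
    print(baseline_latency / current_latency)  # 4.0 -> current run is 4x faster
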
 def main(log_file_path):
     result = parse_log(log_file_path)
     summary_for_plot(result)
@@ -171,6 +304,17 @@ def main(log_file_path):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Parse benchmark log file.")
     parser.add_argument("log_file_path", type=str, help="Path to the log file.")
+    parser.add_argument(
+        "--compare",
+        "-c",
+        type=str,
+        default="",
+        help="Path to a baseline log file; enables speedup statistics across the two logs.",
+    )
     args = parser.parse_args()

-    main(args.log_file_path)
+    if args.compare:
+        ENABLE_COMPARE = True
+        compare_main(args.log_file_path, args.compare)
+    else:
+        main(args.log_file_path)
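Usage sketch for the new flag (the script's file name does not appear in this diff; `summary_parser.py` is a placeholder):

    python summary_parser.py run.log                       # single-log summary
    python summary_parser.py run.log --compare base.log    # adds cross-log speedup columns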