Skip to content

Commit d500fd5

Browse files
committed
Fix 'tuple' object has no attribute 'to_dict' for bert
Use custom_write_calibration_table for migraphx
1 parent afba506 commit d500fd5

File tree

1 file changed

+134
-1
lines changed

1 file changed

+134
-1
lines changed

quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py

+134-1
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,105 @@ def get_op_nodes_not_followed_by_specific_op(model, op1, op2):
277277

278278
return not_selected_op1_nodes
279279

280+
def custom_write_calibration_table(calibration_cache, dir="."):
281+
"""
282+
Helper function to write calibration table to files.
283+
"""
284+
285+
import json
286+
import logging
287+
import flatbuffers
288+
import numpy as np
289+
290+
import onnxruntime.quantization.CalTableFlatBuffers.KeyValue as KeyValue
291+
import onnxruntime.quantization.CalTableFlatBuffers.TrtTable as TrtTable
292+
from onnxruntime.quantization.calibrate import CalibrationMethod, TensorData, TensorsData
293+
294+
logging.info(f"calibration cache: {calibration_cache}")
295+
296+
class MyEncoder(json.JSONEncoder):
297+
def default(self, obj):
298+
if isinstance(obj, (TensorData, TensorsData)):
299+
return obj.to_dict()
300+
if isinstance(obj, TensorDataWrapper):
301+
return obj.data_dict
302+
if isinstance(obj, np.ndarray):
303+
return {"data": obj.tolist(), "dtype": str(obj.dtype), "CLS": "numpy.array"}
304+
if isinstance(obj, CalibrationMethod):
305+
return {"CLS": obj.__class__.__name__, "value": str(obj)}
306+
return json.JSONEncoder.default(self, obj)
307+
308+
json_data = json.dumps(calibration_cache, cls=MyEncoder)
309+
310+
with open(os.path.join(dir, "calibration.json"), "w") as file:
311+
file.write(json_data) # use `json.loads` to do the reverse
312+
313+
# Serialize data using FlatBuffers
314+
zero = np.array(0)
315+
builder = flatbuffers.Builder(1024)
316+
key_value_list = []
317+
318+
for key in sorted(calibration_cache.keys()):
319+
values = calibration_cache[key]
320+
d_values = values.to_dict()
321+
322+
highest = d_values.get("highest", zero)
323+
lowest = d_values.get("lowest", zero)
324+
325+
highest_val = highest.item() if hasattr(highest, "item") else float(highest)
326+
lowest_val = lowest.item() if hasattr(lowest, "item") else float(lowest)
327+
328+
floats = [float(highest_val), float(lowest_val)]
329+
330+
value = str(max(floats))
331+
332+
flat_key = builder.CreateString(key)
333+
flat_value = builder.CreateString(value)
334+
335+
KeyValue.KeyValueStart(builder)
336+
KeyValue.KeyValueAddKey(builder, flat_key)
337+
KeyValue.KeyValueAddValue(builder, flat_value)
338+
key_value = KeyValue.KeyValueEnd(builder)
339+
340+
key_value_list.append(key_value)
341+
342+
343+
TrtTable.TrtTableStartDictVector(builder, len(key_value_list))
344+
for key_value in key_value_list:
345+
builder.PrependUOffsetTRelative(key_value)
346+
main_dict = builder.EndVector()
347+
348+
TrtTable.TrtTableStart(builder)
349+
TrtTable.TrtTableAddDict(builder, main_dict)
350+
cal_table = TrtTable.TrtTableEnd(builder)
351+
352+
builder.Finish(cal_table)
353+
buf = builder.Output()
354+
355+
with open(os.path.join(dir, "calibration.flatbuffers"), "wb") as file:
356+
file.write(buf)
357+
358+
# Deserialize data (for validation)
359+
if os.environ.get("QUANTIZATION_DEBUG", 0) in (1, "1"):
360+
cal_table = TrtTable.TrtTable.GetRootAsTrtTable(buf, 0)
361+
dict_len = cal_table.DictLength()
362+
for i in range(dict_len):
363+
key_value = cal_table.Dict(i)
364+
logging.info(key_value.Key())
365+
logging.info(key_value.Value())
366+
367+
# write plain text
368+
with open(os.path.join(dir, "calibration.cache"), "w") as file:
369+
for key in sorted(calibration_cache.keys()):
370+
values = calibration_cache[key]
371+
d_values = values.to_dict()
372+
floats = [
373+
float(d_values.get("highest", zero).item()),
374+
float(d_values.get("lowest", zero).item()),
375+
]
376+
value = key + " " + str(max(floats))
377+
file.write(value)
378+
file.write("\n")
280379

281380
def parse_input_args():
282381
parser = argparse.ArgumentParser()
@@ -553,8 +652,42 @@ def output_run_config(flags, samples):
553652
for k, v in compute_range.data.items():
554653
json_compute_range[k] = (float(v.range_value[0]), float(v.range_value[1]))
555654

655+
print("Writing calibration table")
656+
try:
657+
write_calibration_table(json_compute_range)
658+
except AttributeError as e:
659+
class TensorDataWrapper:
660+
def __init__(self, data_dict):
661+
self.data_dict = data_dict
662+
663+
def to_dict(self):
664+
return self.data_dict
665+
666+
def __repr__(self):
667+
return repr(self.data_dict)
668+
669+
def __serializable__(self):
670+
return self.data_dict
671+
672+
calibration_data = {}
673+
for k, v in compute_range.data.items():
674+
if hasattr(v, 'to_dict'):
675+
tensor_dict = v.to_dict()
676+
processed_dict = {}
677+
for dk, dv in tensor_dict.items():
678+
if isinstance(dv, np.ndarray):
679+
processed_dict[dk] = dv.item() if dv.size == 1 else dv.tolist()
680+
elif isinstance(dv, np.number):
681+
processed_dict[dk] = dv.item()
682+
else:
683+
processed_dict[dk] = dv
684+
calibration_data[k] = TensorDataWrapper(processed_dict)
685+
else:
686+
calibration_data[k] = v
687+
688+
print("Using custom calibration table function")
689+
custom_write_calibration_table(calibration_data)
556690

557-
write_calibration_table(json_compute_range)
558691
print("Calibration is done. Calibration cache is saved to calibration.json")
559692

560693
model_quants = model_quants + "_int8"

0 commit comments

Comments
 (0)