Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,121 @@
import onnxruntime
from onnxruntime.quantization import CalibrationDataReader, create_calibrator, write_calibration_table

def custom_write_calibration_table(calibration_cache, filename):
"""
Helper function to write calibration table to files.
"""

import json
import logging
import flatbuffers
import numpy as np

import onnxruntime.quantization.CalTableFlatBuffers.KeyValue as KeyValue
import onnxruntime.quantization.CalTableFlatBuffers.TrtTable as TrtTable
from onnxruntime.quantization.calibrate import CalibrationMethod, TensorData, TensorsData

logging.info(f"calibration cache: {calibration_cache}")

class MyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, (TensorData, TensorsData)):
return obj.to_dict()
if isinstance(obj, TensorDataWrapper):
return obj.data_dict
if isinstance(obj, np.ndarray):
return {"data": obj.tolist(), "dtype": str(obj.dtype), "CLS": "numpy.array"}
if isinstance(obj, CalibrationMethod):
return {"CLS": obj.__class__.__name__, "value": str(obj)}
return json.JSONEncoder.default(self, obj)

json_data = json.dumps(calibration_cache, cls=MyEncoder)

with open(filename, "w") as file:
file.write(json_data) # use `json.loads` to do the reverse

# Serialize data using FlatBuffers
zero = np.array(0)
builder = flatbuffers.Builder(1024)
key_value_list = []

for key in sorted(calibration_cache.keys()):
values = calibration_cache[key]
d_values = values.to_dict()

highest = d_values.get("highest", zero)
lowest = d_values.get("lowest", zero)

highest_val = highest.item() if hasattr(highest, "item") else float(highest)
lowest_val = lowest.item() if hasattr(lowest, "item") else float(lowest)

floats = [float(highest_val), float(lowest_val)]

value = str(max(floats))

flat_key = builder.CreateString(key)
flat_value = builder.CreateString(value)

KeyValue.KeyValueStart(builder)
KeyValue.KeyValueAddKey(builder, flat_key)
KeyValue.KeyValueAddValue(builder, flat_value)
key_value = KeyValue.KeyValueEnd(builder)

key_value_list.append(key_value)


TrtTable.TrtTableStartDictVector(builder, len(key_value_list))
for key_value in key_value_list:
builder.PrependUOffsetTRelative(key_value)
main_dict = builder.EndVector()

TrtTable.TrtTableStart(builder)
TrtTable.TrtTableAddDict(builder, main_dict)
cal_table = TrtTable.TrtTableEnd(builder)

builder.Finish(cal_table)
buf = builder.Output()

with open(filename, "wb") as file:
file.write(buf)

# Deserialize data (for validation)
if os.environ.get("QUANTIZATION_DEBUG", 0) in (1, "1"):
cal_table = TrtTable.TrtTable.GetRootAsTrtTable(buf, 0)
dict_len = cal_table.DictLength()
for i in range(dict_len):
key_value = cal_table.Dict(i)
logging.info(key_value.Key())
logging.info(key_value.Value())

# write plain text
with open(filename + ".cache", "w") as file:
for key in sorted(calibration_cache.keys()):
values = calibration_cache[key]
d_values = values.to_dict()
highest = d_values.get("highest", zero)
lowest = d_values.get("lowest", zero)

highest_val = highest.item() if hasattr(highest, "item") else float(highest)
lowest_val = lowest.item() if hasattr(lowest, "item") else float(lowest)

floats = [float(highest_val), float(lowest_val)]

value = key + " " + str(max(floats))
file.write(value)
file.write("\n")


def parse_input_args():
parser = argparse.ArgumentParser()

parser.add_argument(
"--model",
required=False,
default='./resnet50-v2-7.onnx',
help='Target DIR for model. Default is ./resnet50-v2-7.onnx',
)

parser.add_argument(
"--fp16",
action="store_true",
Expand All @@ -29,6 +141,14 @@ def parse_input_args():
help='Perform no quantization',
)

parser.add_argument(
"--fp8",
action="store_true",
required=False,
default=False,
help='Perform fp8 quantizaton instead of int8',
)

parser.add_argument(
"--image_dir",
required=False,
Expand All @@ -48,6 +168,29 @@ def parse_input_args():
help='Size of images for calibration',
type=int)

parser.add_argument(
"--exhaustive_tune",
action="store_true",
required=False,
default=False,
help='Enable MIGraphX Exhaustive tune before compile. Default False',
)

parser.add_argument(
"--cache",
action="store_true",
required=False,
default=True,
help='cache the compiled model between runs. Saves quantization and compile time. Default true',
)

parser.add_argument(
"--cache_name",
required=False,
default="./cached_model.mxr",
help='Name and path of the compiled model cache. Default: ./cached_model.mxr',
)

return parser.parse_args()

class ImageNetDataReader(CalibrationDataReader):
Expand Down Expand Up @@ -255,6 +398,7 @@ class ImageClassificationEvaluator:
def __init__(self,
model_path,
synset_id,
flags,
data_reader: CalibrationDataReader,
providers=["MIGraphXExecutionProvider"]):
'''
Expand All @@ -276,10 +420,21 @@ def get_result(self):

def predict(self):
sess_options = onnxruntime.SessionOptions()
sess_options.log_severity_level = 0
sess_options.log_verbosity_level = 0
sess_options.log_severity_level = 2
sess_options.log_verbosity_level = 2
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
session = onnxruntime.InferenceSession(self.model_path, sess_options=sess_options, providers=self.providers)
session = onnxruntime.InferenceSession(self.model_path, sess_options=sess_options,
providers=[("MIGraphXExecutionProvider",
{"migraphx_fp8_enable": flags.fp8 and not flags.fp32,
"migraphx_int8_enable": not (flags.fp8 or flags.fp32),
"migraphx_fp16_enable": flags.fp16 and not flags.fp32,
"migraphx_int8_calibration_table_name": flags.calibration_table,
"migraphx_use_native_calibration_table": flags.native_calibration_table,
"migraphx_save_compiled_model": flags.cache,
"migraphx_save_model_path": flags.cache_name,
"migraphx_load_compiled_model": flags.cache,
"migraphx_load_model_path": flags.cache_name,
"migraphx_exhaustive_tune": flags.exhaustive_tune})])

inference_outputs_list = []
while True:
Expand Down Expand Up @@ -362,21 +517,31 @@ def get_dataset_size(dataset_path, calibration_dataset_size):
flags = parse_input_args()

# Dataset settings
model_path = "./resnet50-v2-7.onnx"
model_path = flags.model
ilsvrc2012_dataset_path = flags.image_dir
augmented_model_path = "./augmented_model.onnx"
batch_size = flags.batch
calibration_dataset_size = 0 if flags.fp32 else flags.cal_size # Size of dataset for calibration

precision=""

if not (flags.fp8 or flags.fp32):
precision = precision + "_int8"

if flags.fp8 and not flags.fp32:
precision = precision + "_fp8"

if flags.fp16 and not flags.fp32:
precision = "_fp16" + precision

calibration_table_generation_enable = False
if not flags.fp32:
# INT8 calibration setting
calibration_table_generation_enable = True # Enable/Disable INT8 calibration

# MIGraphX EP INT8 settings
os.environ["ORT_MIGRAPHX_INT8_ENABLE"] = "1" # Enable INT8 precision
os.environ["ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"] = "calibration.flatbuffers" # Calibration table name
os.environ["ORT_MIGRAPHX_INT8_NATIVE_CALIBRATION_TABLE"] = "0" # Calibration table name
flags.calibration_table = "calibration_cal"+ str(flags.cal_size) + precision + ".flatbuffers"
flags.native_calibration_table = "False"
if os.path.isfile("./" + flags.calibration_table):
calibration_table_generation = False
print("Found previous calibration: " + flags.calibration_table + "Skipping generating table")

execution_provider = ["MIGraphXExecutionProvider"]

Expand All @@ -396,25 +561,46 @@ def get_dataset_size(dataset_path, calibration_dataset_size):
start_index=0,
end_index=calibration_dataset_size,
stride=calibration_dataset_size,
batch_size=batch_size,
batch_size=1,
model_path=augmented_model_path,
input_name=input_name)
calibrator.collect_data(data_reader)
cal_tensors = calibrator.compute_data()

serial_cal_tensors = {}
for keys, values in cal_tensors.data.items():
serial_cal_tensors[keys] = [float(x[0]) for x in values.range_value]
class TensorDataWrapper:
def __init__(self, data_dict):
self.data_dict = data_dict

def to_dict(self):
return self.data_dict

def __repr__(self):
return repr(self.data_dict)

def __serializable__(self):
return self.data_dict

calibration_data = {}
for k, v in cal_tensors.data.items():
if hasattr(v, 'to_dict'):
tensor_dict = v.to_dict()
processed_dict = {}
for dk, dv in tensor_dict.items():
if isinstance(dv, np.ndarray):
processed_dict[dk] = dv.item() if dv.size == 1 else dv.tolist()
elif isinstance(dv, np.number):
processed_dict[dk] = dv.item()
else:
processed_dict[dk] = dv
calibration_data[k] = TensorDataWrapper(processed_dict)
else:
calibration_data[k] = v

print("Writing calibration table")
write_calibration_table(serial_cal_tensors)
print("Writing calibration table to:" + flags.calibration_table)
custom_write_calibration_table(calibration_data, flags.calibration_table)
os.rename("./calibration.flatbuffers", flags.calibration_table)
print("Write complete")

if flags.fp16:
os.environ["ORT_MIGRAPHX_FP16_ENABLE"] = "1"
else:
os.environ["ORT_MIGRAPHX_FP16_ENABLE"] = "0"

# Run prediction in MIGraphX EP138G
data_reader = ImageNetDataReader(ilsvrc2012_dataset_path,
start_index=calibration_dataset_size,
Expand All @@ -427,14 +613,9 @@ def get_dataset_size(dataset_path, calibration_dataset_size):
synset_id = data_reader.get_synset_id(ilsvrc2012_dataset_path, calibration_dataset_size,
prediction_dataset_size) # Generate synset id
print("Prepping Evalulator")
evaluator = ImageClassificationEvaluator(new_model_path, synset_id, data_reader, providers=execution_provider)
evaluator = ImageClassificationEvaluator(new_model_path, synset_id, flags, data_reader, providers=execution_provider)
print("Performing Predictions")
evaluator.predict()
print("Read out answer")
result = evaluator.get_result()
evaluator.evaluate(result)

#Set OS flags to off to ensure we don't interfere with other test runs

os.environ["ORT_MIGRAPHX_FP16_ENABLE"] = "0"
os.environ["ORT_MIGRAPHX_INT8_ENABLE"] = "0"
Loading