
Update tests to use fp8 #508


Draft · wants to merge 6 commits into base: main
@@ -13,6 +13,13 @@
def parse_input_args():
parser = argparse.ArgumentParser()

parser.add_argument(
"--model",
required=False,
default='./resnet50-v2-7.onnx',
help='Target DIR for model. Default is ./resnet50-v2-7.onnx',
)

parser.add_argument(
"--fp16",
action="store_true",
@@ -29,6 +36,14 @@ def parse_input_args():
help='Perform no quantization',
)

parser.add_argument(
"--fp8",
action="store_true",
required=False,
default=False,
help='Perform fp8 quantization instead of int8',
)

parser.add_argument(
"--image_dir",
required=False,
@@ -48,6 +63,29 @@ def parse_input_args():
help='Size of images for calibration',
type=int)

parser.add_argument(
"--exhaustive_tune",
action="store_true",
required=False,
default=False,
help='Enable MIGraphX exhaustive tuning before compile. Default: False',
)

parser.add_argument(
"--cache",
action="store_true",
required=False,
default=True,
help='Cache the compiled model between runs. Saves quantization and compile time. Default: True',
)

parser.add_argument(
"--cache_name",
required=False,
default="./cached_model.mxr",
help='Name and path of the compiled model cache. Default: ./cached_model.mxr',
)

return parser.parse_args()

class ImageNetDataReader(CalibrationDataReader):
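For context on how the new flags above are meant to be combined, a hypothetical invocation of the updated script is sketched below; the script name and the cache path are placeholders chosen for illustration, not taken from this PR.

import subprocess

# Hypothetical fp8 run: calibrate on 100 images, enable exhaustive tuning,
# and cache the compiled model. Script name and paths are placeholders.
subprocess.run([
    "python", "e2e_migraphx_resnet_example.py",
    "--fp8",                               # quantize to fp8 instead of the default int8
    "--cal_size", "100",                   # calibrate on 100 images
    "--exhaustive_tune",                   # let MIGraphX tune kernels before compile
    "--cache_name", "./resnet50_fp8.mxr",  # where the compiled model is cached
], check=True)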
@@ -255,6 +293,7 @@ class ImageClassificationEvaluator:
def __init__(self,
model_path,
synset_id,
flags,
data_reader: CalibrationDataReader,
providers=["MIGraphXExecutionProvider"]):
'''
@@ -276,10 +315,21 @@ def get_result(self):

def predict(self):
sess_options = onnxruntime.SessionOptions()
sess_options.log_severity_level = 0
sess_options.log_verbosity_level = 0
sess_options.log_severity_level = 2
sess_options.log_verbosity_level = 2
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
session = onnxruntime.InferenceSession(self.model_path, sess_options=sess_options, providers=self.providers)
session = onnxruntime.InferenceSession(self.model_path, sess_options=sess_options,
providers=[("MIGraphXExecutionProvider",
{"migraphx_fp8_enable": flags.fp8 and not flags.fp32,
"migraphx_int8_enable": not (flags.fp8 or flags.fp32),
"migraphx_fp16_enable": flags.fp16 and not flags.fp32,
"migraphx_int8_calibration_table_name": flags.calibration_table,
"migraphx_use_native_calibration_table": flags.native_calibration_table,
"migraphx_save_compiled_model": flags.cache,
"migraphx_save_model_path": flags.cache_name,
"migraphx_load_compiled_model": flags.cache,
"migraphx_load_model_path": flags.cache_name,
"migraphx_exhaustive_tune": flags.exhaustive_tune})])

inference_outputs_list = []
while True:
@@ -362,21 +412,31 @@ def get_dataset_size(dataset_path, calibration_dataset_size):
flags = parse_input_args()

# Dataset settings
model_path = "./resnet50-v2-7.onnx"
model_path = flags.model
ilsvrc2012_dataset_path = flags.image_dir
augmented_model_path = "./augmented_model.onnx"
batch_size = flags.batch
calibration_dataset_size = 0 if flags.fp32 else flags.cal_size # Size of dataset for calibration

precision=""

if not (flags.fp8 or flags.fp32):
precision = precision + "_int8"

if flags.fp8 and not flags.fp32:
precision = precision + "_fp8"

if flags.fp16 and not flags.fp32:
precision = "_fp16" + precision

calibration_table_generation_enable = False
if not flags.fp32:
# INT8 calibration setting
calibration_table_generation_enable = True # Enable/Disable INT8 calibration

# MIGraphX EP INT8 settings
os.environ["ORT_MIGRAPHX_INT8_ENABLE"] = "1" # Enable INT8 precision
os.environ["ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"] = "calibration.flatbuffers" # Calibration table name
os.environ["ORT_MIGRAPHX_INT8_NATIVE_CALIBRATION_TABLE"] = "0" # Calibration table name
flags.calibration_table = "calibration_cal"+ str(flags.cal_size) + precision + ".flatbuffers"
flags.native_calibration_table = "False"
if os.path.isfile("./" + flags.calibration_table):
calibration_table_generation_enable = False
print("Found previous calibration: " + flags.calibration_table + ". Skipping table generation")

execution_provider = ["MIGraphXExecutionProvider"]

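As a quick worked check of the calibration-table naming above (the flag values here are illustrative, not the script's defaults): an int8 run with --fp16 and --cal_size 100 produces calibration_cal100_fp16_int8.flatbuffers.

# Illustrative reproduction of the table-name construction above; not part of the PR.
cal_size, fp8, fp16, fp32 = 100, False, True, False
precision = ""
if not (fp8 or fp32):
    precision = precision + "_int8"
if fp8 and not fp32:
    precision = precision + "_fp8"
if fp16 and not fp32:
    precision = "_fp16" + precision
assert "calibration_cal" + str(cal_size) + precision + ".flatbuffers" \
    == "calibration_cal100_fp16_int8.flatbuffers"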
@@ -406,15 +466,11 @@ def get_dataset_size(dataset_path, calibration_dataset_size):
for keys, values in cal_tensors.data.items():
serial_cal_tensors[keys] = [float(x[0]) for x in values.range_value]

print("Writing calibration table")
print("Writing calibration table to:" + flags.calibration_table)
write_calibration_table(serial_cal_tensors)
os.rename("./calibration.flatbuffers", flags.calibration_table)
print("Write complete")

if flags.fp16:
os.environ["ORT_MIGRAPHX_FP16_ENABLE"] = "1"
else:
os.environ["ORT_MIGRAPHX_FP16_ENABLE"] = "0"

# Run prediction in MIGraphX EP
data_reader = ImageNetDataReader(ilsvrc2012_dataset_path,
start_index=calibration_dataset_size,
@@ -427,14 +483,9 @@ def get_dataset_size(dataset_path, calibration_dataset_size):
synset_id = data_reader.get_synset_id(ilsvrc2012_dataset_path, calibration_dataset_size,
prediction_dataset_size) # Generate synset id
print("Prepping Evalulator")
evaluator = ImageClassificationEvaluator(new_model_path, synset_id, data_reader, providers=execution_provider)
evaluator = ImageClassificationEvaluator(new_model_path, synset_id, flags, data_reader, providers=execution_provider)
print("Performing Predictions")
evaluator.predict()
print("Read out answer")
result = evaluator.get_result()
evaluator.evaluate(result)

#Set OS flags to off to ensure we don't interfere with other test runs

os.environ["ORT_MIGRAPHX_FP16_ENABLE"] = "0"
os.environ["ORT_MIGRAPHX_INT8_ENABLE"] = "0"