@classmethod
def from_dict(cls, d: dict) -> "TensorData":
    """Build an instance from a mapping shaped like the output of to_dict().

    The ``"CLS"`` entry is dropped. Any value that is itself a dict tagged
    ``{"CLS": "numpy.array"}`` is turned back into a numpy array from its
    ``"data"`` and ``"dtype"`` entries; every other value is passed through
    unchanged. The remaining entries become keyword arguments to ``cls``.
    """

    def _decode(item):
        # Arrays travel through JSON as {"CLS": "numpy.array", "data": ..., "dtype": ...}.
        is_tagged_array = isinstance(item, dict) and item.get("CLS") == "numpy.array"
        if not is_tagged_array:
            return item
        return np.array(item["data"], dtype=np.dtype(item["dtype"]))

    fields = {key: _decode(item) for key, item in d.items() if key != "CLS"}
    return cls(**fields)
class _CalibrationCacheEncoder(json.JSONEncoder):
    """JSON encoder used by save_tensors_data() for calibration caches.

    Handles the values the standard encoder cannot serialize:

    * ``np.ndarray`` -> ``{"data": <nested list>, "dtype": <str>, "CLS": "numpy.array"}``
    * numpy integer / floating scalars -> plain ``int`` / ``float``
    * ``TensorData`` / ``TensorsData`` -> the result of their ``to_dict()``
    * ``CalibrationMethod`` -> ``{"CLS": <class name>, "value": str(member)}``

    Anything else is delegated to ``json.JSONEncoder.default``, which raises
    ``TypeError``.
    """

    def default(self, obj):
        # The branches are mutually exclusive, so their order does not affect the result.
        if isinstance(obj, np.ndarray):
            return {"data": obj.tolist(), "dtype": str(obj.dtype), "CLS": "numpy.array"}
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, (TensorData, TensorsData)):
            return obj.to_dict()
        if isinstance(obj, CalibrationMethod):
            return {"CLS": obj.__class__.__name__, "value": str(obj)}
        return super().default(obj)


def save_tensors_data(tensors_data: "TensorsData", path: "str | Path") -> None:
    """Serialize calibration tensor ranges to a JSON file at *path*.

    Missing parent directories are created. The JSON is first written to a
    uniquely named temporary file next to *path*, and os.replace() then moves
    it over *path*, so a failed or interrupted write does not overwrite an
    existing cache.

    Args:
        tensors_data: object to serialize with _CalibrationCacheEncoder.
        path: destination file.

    Raises:
        TypeError: if *tensors_data* contains a value the encoder cannot handle.
        OSError: if the directory or file cannot be written.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    # A per-call unique name keeps concurrent writers of the same cache from
    # truncating each other's temporary file.
    tmp = path.with_name(f"{path.name}.{uuid.uuid4().hex}.tmp")
    try:
        # Leaving the with-block closes (and therefore flushes) the file before the rename.
        with tmp.open("w", encoding="utf-8") as f:
            json.dump(tensors_data, f, cls=_CalibrationCacheEncoder)
        os.replace(tmp, path)
    except BaseException:
        # Do not leave a partially written temp file beside the cache.
        tmp.unlink(missing_ok=True)
        raise


def load_tensors_data(path: "str | Path") -> "TensorsData":
    """Load calibration tensor ranges from a JSON file written by save_tensors_data().

    Args:
        path: location of the cache file.

    Returns:
        The result of ``TensorsData.from_dict`` applied to the parsed JSON object.

    Raises:
        FileNotFoundError: if *path* does not exist.
        ValueError: if *path* is not a regular file, if the content is not
            valid JSON (``json.JSONDecodeError`` is a ``ValueError``), or if
            the top-level JSON value is not an object.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"Calibration cache not found: {path}")
    if not path.is_file():
        raise ValueError(f"Calibration cache path is not a file: {path}")
    with path.open("r", encoding="utf-8") as f:
        payload = json.load(f)
    if not isinstance(payload, dict):
        # Fail here with the file name instead of with an unrelated error inside from_dict().
        raise ValueError(
            f"Calibration cache {path} is malformed: expected a JSON object, got {type(payload).__name__}."
        )
    return TensorsData.from_dict(payload)
.calibrate import CalibrationDataReader, CalibrationMethod, TensorsData, create_calibrator +from .calibrate import ( + CalibrationDataReader, + CalibrationMethod, + TensorsData, + create_calibrator, + load_tensors_data, + save_tensors_data, +) from .onnx_quantizer import ONNXQuantizer from .qdq_quantizer import QDQQuantizer from .quant_utils import ( @@ -479,7 +486,7 @@ def check_static_quant_arguments(quant_format: QuantFormat, activation_type: Qua def quantize_static( model_input: str | Path | onnx.ModelProto, model_output: str | Path, - calibration_data_reader: CalibrationDataReader, + calibration_data_reader: CalibrationDataReader | None = None, quant_format=QuantFormat.QDQ, op_types_to_quantize=None, per_channel=False, @@ -492,6 +499,7 @@ def quantize_static( calibrate_method=CalibrationMethod.MinMax, calibration_providers=None, extra_options=None, + calibration_cache_path: str | Path | None = None, ): """ Given an onnx model and calibration data reader, create a quantized onnx model and save it into a file @@ -506,7 +514,13 @@ def quantize_static( model_output: file path of quantized model calibration_data_reader: a calibration data reader. It enumerates calibration data and generates inputs for the - original model. + original model. May be None if calibration_cache_path points to an + existing cache file. + calibration_cache_path: optional path to a JSON calibration cache. If + the file already exists, calibration inference is skipped and the + cached tensor ranges are loaded instead. If the file does not yet + exist, calibration runs normally and the result is saved to this + path for future reuse. quant_format: QuantFormat{QOperator, QDQ}. QOperator format quantizes the model with quantized operators directly. QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor. 
@@ -673,6 +687,11 @@ def quantize_static( } if extra_options.get("SmoothQuant", False): + if calibration_data_reader is None: + raise ValueError( + "SmoothQuant requires a non-None calibration_data_reader; the calibration cache " + "stores per-tensor ranges only and cannot drive the SmoothQuant transform." + ) import importlib # noqa: PLC0415 try: @@ -704,48 +723,68 @@ def inc_dataloader(): if is_model_updated: model = updated_model - with tempfile.TemporaryDirectory(prefix="ort.quant.") as quant_tmp_dir: - if is_model_updated: - # Update model_input and avoid to use the original one - model_input = copy.deepcopy(model) - - if isinstance(model_input, onnx.ModelProto): - output_path = Path(quant_tmp_dir).joinpath("model_input.onnx").as_posix() - onnx.save_model( - model_input, - output_path, - save_as_external_data=True, + _cache_path = Path(calibration_cache_path) if calibration_cache_path is not None else None + if _cache_path is not None and _cache_path.exists() and not _cache_path.is_file(): + raise ValueError(f"calibration_cache_path is not a file: {_cache_path}") + _cache_hit = _cache_path is not None and _cache_path.is_file() + + if _cache_hit: + tensors_range = load_tensors_data(_cache_path) + if tensors_range.calibration_method != calibrate_method: + raise ValueError( + f"Calibration cache at {_cache_path} was produced with " + f"{tensors_range.calibration_method}, but quantize_static was called " + f"with calibrate_method={calibrate_method}. Delete the cache or " + f"pass a matching calibrate_method." 
+ ) + else: + if calibration_data_reader is None: + raise ValueError("Either calibration_data_reader or an existing calibration_cache_path must be provided.") + with tempfile.TemporaryDirectory(prefix="ort.quant.") as quant_tmp_dir: + if is_model_updated: + # Update model_input and avoid to use the original one + model_input = copy.deepcopy(model) + + if isinstance(model_input, onnx.ModelProto): + output_path = Path(quant_tmp_dir).joinpath("model_input.onnx").as_posix() + onnx.save_model( + model_input, + output_path, + save_as_external_data=True, + ) + model_input = output_path + + calibrator = create_calibrator( + Path(model_input), + op_types_to_quantize, + augmented_model_path=Path(quant_tmp_dir).joinpath("augmented_model.onnx").as_posix(), + calibrate_method=calibrate_method, + use_external_data_format=use_external_data_format, + providers=calibration_providers, + extra_options=calib_extra_options, ) - model_input = output_path - - calibrator = create_calibrator( - Path(model_input), - op_types_to_quantize, - augmented_model_path=Path(quant_tmp_dir).joinpath("augmented_model.onnx").as_posix(), - calibrate_method=calibrate_method, - use_external_data_format=use_external_data_format, - providers=calibration_providers, - extra_options=calib_extra_options, - ) - - stride = extra_options.get("CalibStridedMinMax", None) - if stride: - total_data_size = len(calibration_data_reader) - if total_data_size % stride != 0: - raise ValueError(f"Total data size ({total_data_size}) is not divisible by stride size ({stride}).") - for start in range(0, total_data_size, stride): - end_index = start + stride - calibration_data_reader.set_range(start_index=start, end_index=end_index) + stride = extra_options.get("CalibStridedMinMax", None) + if stride: + total_data_size = len(calibration_data_reader) + if total_data_size % stride != 0: + raise ValueError(f"Total data size ({total_data_size}) is not divisible by stride size ({stride}).") + + for start in range(0, total_data_size, 
class TestCalibrationCache(unittest.TestCase):
    """Calibration cache tests: save/load round trips and reuse through quantize_static."""

    @classmethod
    def setUpClass(cls):
        # One scratch directory is shared by every test in this class.
        cls._tmp_dir = tempfile.TemporaryDirectory(prefix="test_calibration_cache.")

    @classmethod
    def tearDownClass(cls):
        cls._tmp_dir.cleanup()

    def _scratch(self, *parts):
        """Return a path below the shared scratch directory."""
        return Path(self._tmp_dir.name).joinpath(*parts)

    @staticmethod
    def _f32(value):
        """Return *value* as a 0-d float32 array."""
        return np.array(value, dtype=np.float32)

    def _make_simple_model(self, path):
        """Save a two-node Conv -> Relu float model to *path*."""
        io_shape = [1, 3, 1, 3]
        graph_input = helper.make_tensor_value_info("input", TensorProto.FLOAT, io_shape)
        graph_output = helper.make_tensor_value_info("X6", TensorProto.FLOAT, io_shape)
        weight = generate_input_initializer([3, 3, 1, 1], np.float32, "W1")
        bias = generate_input_initializer([3], np.float32, "B1")
        nodes = [
            helper.make_node("Conv", ["input", "W1", "B1"], ["X2"], name="Conv1"),
            helper.make_node("Relu", ["X2"], ["X6"], name="Relu1"),
        ]
        graph = helper.make_graph(nodes, "cache_test_graph", [graph_input], [graph_output])
        for initializer in (weight, bias):
            graph.initializer.add().CopyFrom(initializer)
        onnx.save(helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]), path)

    def test_save_load_tensors_data_minmax_roundtrip(self):
        """A MinMax entry survives save + load with its values and 0-d shape intact."""
        original = TensorsData(
            CalibrationMethod.MinMax,
            {"x": TensorData(lowest=self._f32(-1.0), highest=self._f32(2.0))},
        )
        cache_file = self._scratch("minmax_cache.json")
        save_tensors_data(original, cache_file)
        self.assertTrue(cache_file.exists())

        restored = load_tensors_data(cache_file)
        self.assertEqual(restored.calibration_method, CalibrationMethod.MinMax)
        self.assertEqual(list(restored.keys()), ["x"])
        low, high = restored["x"].range_value
        np.testing.assert_array_equal(low, self._f32(-1.0))
        np.testing.assert_array_equal(high, self._f32(2.0))
        for bound in (low, high):
            self.assertEqual(bound.shape, ())

    def test_save_load_tensors_data_entropy_roundtrip(self):
        """An Entropy entry keeps its range, histogram and histogram edges."""
        counts = np.array([1.0, 2.0, 3.0], dtype=np.float32)
        edges = np.array([0.0, 1.0, 2.0, 3.0], dtype=np.float32)
        entry = TensorData(
            lowest=self._f32(-0.5),
            highest=self._f32(0.5),
            hist=counts,
            hist_edges=edges,
        )
        cache_file = self._scratch("entropy_cache.json")
        save_tensors_data(TensorsData(CalibrationMethod.Entropy, {"y": entry}), cache_file)

        restored = load_tensors_data(cache_file)
        self.assertEqual(restored.calibration_method, CalibrationMethod.Entropy)
        low, high = restored["y"].range_value
        comparisons = (
            (low, self._f32(-0.5)),
            (high, self._f32(0.5)),
            (restored["y"].hist, counts),
            (restored["y"].hist_edges, edges),
        )
        for actual, expected in comparisons:
            np.testing.assert_array_almost_equal(actual, expected)

    def test_load_tensors_data_invalid_path(self):
        """Loading from a path that does not exist raises FileNotFoundError."""
        with self.assertRaises(FileNotFoundError):
            load_tensors_data("/nonexistent/path/cache.json")

    def test_quantize_static_calibration_cache_path(self):
        """The first run writes the cache; the second run quantizes from it without a reader."""
        model_file = self._scratch("tiny_model.onnx")
        self._make_simple_model(str(model_file))

        cache_file = self._scratch("quant_cache.json")
        first_output = self._scratch("quantized1.onnx")
        second_output = self._scratch("quantized2.onnx")

        # Run 1: a reader is supplied, so calibration runs and the cache file is created.
        reader = TestDataReader()
        quantize_static(
            str(model_file),
            str(first_output),
            calibration_data_reader=reader,
            calibration_cache_path=cache_file,
        )
        self.assertTrue(cache_file.exists())
        ranges_after_first = load_tensors_data(cache_file)

        # Run 2: no reader is supplied; the existing cache is the only source of ranges.
        quantize_static(
            str(model_file),
            str(second_output),
            calibration_data_reader=None,
            calibration_cache_path=cache_file,
        )
        self.assertTrue(second_output.exists())
        ranges_after_second = load_tensors_data(cache_file)
        self.assertEqual(ranges_after_first.calibration_method, ranges_after_second.calibration_method)

    def test_quantize_static_no_reader_no_cache_raises(self):
        """With neither a reader nor a cache path, quantize_static raises ValueError."""
        model_file = self._scratch("tiny_model2.onnx")
        self._make_simple_model(str(model_file))
        output_file = self._scratch("quantized_err.onnx")

        with self.assertRaises(ValueError):
            quantize_static(str(model_file), str(output_file), calibration_data_reader=None)

    def test_save_tensors_data_creates_parent_dir(self):
        """save_tensors_data creates missing parent directories of the target."""
        target = self._scratch("nested", "dir", "cache.json")
        ranges = TensorsData(
            CalibrationMethod.MinMax,
            {"x": TensorData(lowest=self._f32(-1.0), highest=self._f32(1.0))},
        )
        save_tensors_data(ranges, target)
        self.assertTrue(target.exists())

    def test_save_tensors_data_handles_scalar_bins(self):
        """A numpy integer scalar stored in ``bins`` is serialized and compares equal to 5 after loading."""
        entry = TensorData(
            lowest=self._f32(0.0),
            highest=self._f32(1.0),
            hist=np.array([1, 2], dtype=np.int64),
            bins=np.int64(5),
        )
        cache_file = self._scratch("scalar_bins_cache.json")
        save_tensors_data(TensorsData(CalibrationMethod.Entropy, {"z": entry}), cache_file)
        restored = load_tensors_data(cache_file)
        self.assertEqual(restored["z"].bins, 5)

    def test_load_tensors_data_method_mismatch_raises(self):
        """A cache written for MinMax is rejected when quantize_static is asked for Entropy."""
        model_file = self._scratch("tiny_mismatch.onnx")
        self._make_simple_model(str(model_file))
        cache_file = self._scratch("mismatch_cache.json")
        output_file = self._scratch("quantized_mismatch.onnx")

        # Populate the cache with MinMax ranges.
        reader = TestDataReader()
        quantize_static(
            str(model_file),
            str(output_file),
            calibration_data_reader=reader,
            calibrate_method=CalibrationMethod.MinMax,
            calibration_cache_path=cache_file,
        )

        # Reusing that cache with a different calibration method must fail.
        with self.assertRaises(ValueError):
            quantize_static(
                str(model_file),
                str(output_file),
                calibration_data_reader=None,
                calibrate_method=CalibrationMethod.Entropy,
                calibration_cache_path=cache_file,
            )


if __name__ == "__main__":
    unittest.main()