@@ -277,6 +277,105 @@ def get_op_nodes_not_followed_by_specific_op(model, op1, op2):
277
277
278
278
return not_selected_op1_nodes
279
279
280
+ def custom_write_calibration_table (calibration_cache , dir = "." ):
281
+ """
282
+ Helper function to write calibration table to files.
283
+ """
284
+
285
+ import json
286
+ import logging
287
+ import flatbuffers
288
+ import numpy as np
289
+
290
+ import onnxruntime .quantization .CalTableFlatBuffers .KeyValue as KeyValue
291
+ import onnxruntime .quantization .CalTableFlatBuffers .TrtTable as TrtTable
292
+ from onnxruntime .quantization .calibrate import CalibrationMethod , TensorData , TensorsData
293
+
294
+ logging .info (f"calibration cache: { calibration_cache } " )
295
+
296
+ class MyEncoder (json .JSONEncoder ):
297
+ def default (self , obj ):
298
+ if isinstance (obj , (TensorData , TensorsData )):
299
+ return obj .to_dict ()
300
+ if isinstance (obj , TensorDataWrapper ):
301
+ return obj .data_dict
302
+ if isinstance (obj , np .ndarray ):
303
+ return {"data" : obj .tolist (), "dtype" : str (obj .dtype ), "CLS" : "numpy.array" }
304
+ if isinstance (obj , CalibrationMethod ):
305
+ return {"CLS" : obj .__class__ .__name__ , "value" : str (obj )}
306
+ return json .JSONEncoder .default (self , obj )
307
+
308
+ json_data = json .dumps (calibration_cache , cls = MyEncoder )
309
+
310
+ with open (os .path .join (dir , "calibration.json" ), "w" ) as file :
311
+ file .write (json_data ) # use `json.loads` to do the reverse
312
+
313
+ # Serialize data using FlatBuffers
314
+ zero = np .array (0 )
315
+ builder = flatbuffers .Builder (1024 )
316
+ key_value_list = []
317
+
318
+ for key in sorted (calibration_cache .keys ()):
319
+ values = calibration_cache [key ]
320
+ d_values = values .to_dict ()
321
+
322
+ highest = d_values .get ("highest" , zero )
323
+ lowest = d_values .get ("lowest" , zero )
324
+
325
+ highest_val = highest .item () if hasattr (highest , "item" ) else float (highest )
326
+ lowest_val = lowest .item () if hasattr (lowest , "item" ) else float (lowest )
327
+
328
+ floats = [float (highest_val ), float (lowest_val )]
329
+
330
+ value = str (max (floats ))
331
+
332
+ flat_key = builder .CreateString (key )
333
+ flat_value = builder .CreateString (value )
334
+
335
+ KeyValue .KeyValueStart (builder )
336
+ KeyValue .KeyValueAddKey (builder , flat_key )
337
+ KeyValue .KeyValueAddValue (builder , flat_value )
338
+ key_value = KeyValue .KeyValueEnd (builder )
339
+
340
+ key_value_list .append (key_value )
341
+
342
+
343
+ TrtTable .TrtTableStartDictVector (builder , len (key_value_list ))
344
+ for key_value in key_value_list :
345
+ builder .PrependUOffsetTRelative (key_value )
346
+ main_dict = builder .EndVector ()
347
+
348
+ TrtTable .TrtTableStart (builder )
349
+ TrtTable .TrtTableAddDict (builder , main_dict )
350
+ cal_table = TrtTable .TrtTableEnd (builder )
351
+
352
+ builder .Finish (cal_table )
353
+ buf = builder .Output ()
354
+
355
+ with open (os .path .join (dir , "calibration.flatbuffers" ), "wb" ) as file :
356
+ file .write (buf )
357
+
358
+ # Deserialize data (for validation)
359
+ if os .environ .get ("QUANTIZATION_DEBUG" , 0 ) in (1 , "1" ):
360
+ cal_table = TrtTable .TrtTable .GetRootAsTrtTable (buf , 0 )
361
+ dict_len = cal_table .DictLength ()
362
+ for i in range (dict_len ):
363
+ key_value = cal_table .Dict (i )
364
+ logging .info (key_value .Key ())
365
+ logging .info (key_value .Value ())
366
+
367
+ # write plain text
368
+ with open (os .path .join (dir , "calibration.cache" ), "w" ) as file :
369
+ for key in sorted (calibration_cache .keys ()):
370
+ values = calibration_cache [key ]
371
+ d_values = values .to_dict ()
372
+ floats = [
373
+ float (d_values .get ("highest" , zero ).item ()),
374
+ float (d_values .get ("lowest" , zero ).item ()),
375
+ ]
376
+ value = key + " " + str (max (floats ))
377
+ file .write (value )
378
+ file .write ("\n " )
280
379
281
380
def parse_input_args ():
282
381
parser = argparse .ArgumentParser ()
@@ -553,8 +652,42 @@ def output_run_config(flags, samples):
553
652
for k , v in compute_range .data .items ():
554
653
json_compute_range [k ] = (float (v .range_value [0 ]), float (v .range_value [1 ]))
555
654
655
+ print ("Writing calibration table" )
656
+ try :
657
+ write_calibration_table (json_compute_range )
658
+ except AttributeError as e :
659
+ class TensorDataWrapper :
660
+ def __init__ (self , data_dict ):
661
+ self .data_dict = data_dict
662
+
663
+ def to_dict (self ):
664
+ return self .data_dict
665
+
666
+ def __repr__ (self ):
667
+ return repr (self .data_dict )
668
+
669
+ def __serializable__ (self ):
670
+ return self .data_dict
671
+
672
+ calibration_data = {}
673
+ for k , v in compute_range .data .items ():
674
+ if hasattr (v , 'to_dict' ):
675
+ tensor_dict = v .to_dict ()
676
+ processed_dict = {}
677
+ for dk , dv in tensor_dict .items ():
678
+ if isinstance (dv , np .ndarray ):
679
+ processed_dict [dk ] = dv .item () if dv .size == 1 else dv .tolist ()
680
+ elif isinstance (dv , np .number ):
681
+ processed_dict [dk ] = dv .item ()
682
+ else :
683
+ processed_dict [dk ] = dv
684
+ calibration_data [k ] = TensorDataWrapper (processed_dict )
685
+ else :
686
+ calibration_data [k ] = v
687
+
688
+ print ("Using custom calibration table function" )
689
+ custom_write_calibration_table (calibration_data )
556
690
557
- write_calibration_table (json_compute_range )
558
691
print ("Calibration is done. Calibration cache is saved to calibration.json" )
559
692
560
693
model_quants = model_quants + "_int8"
0 commit comments