Commit df1fd96

jambayk authored and guschmue committed
Quantization tool: Use nanmin, nanmax, nanmean in calibrator (#23749)
### Description

- The calibrator uses `np.max`/`np.min` to get min/max values from the collected data. However, these functions return `nan` if any of the array values is `nan`, which subsequently leads to an invalid scale and a failure during quantization at https://github.com/microsoft/onnxruntime/blob/93689c5995dcacbb99c3afa9ec477b305c71159f/onnxruntime/python/tools/quantization/quant_utils.py#L293.
- When quantizing models with `GroupQueryAttention`, the intermediate activations corresponding to padded tokens can become `nan`. We can safely ignore such values because they do not contribute to the final model output.
- Using `np.nanmax`/`np.nanmin` ensures that the calibrator can handle `nan` values. If all values are `nan`, numpy raises a `RuntimeWarning: All-NaN slice encountered`, which can help debug the eventual scale failure.

```python
import numpy as np

no_nans = np.array([1, 2, 3], dtype=np.float32)
some_nans = np.array([np.nan, 1, 2, 3, np.nan, np.nan], dtype=np.float32)
all_nans = np.array([np.nan, np.nan], dtype=np.float32)

for array in [no_nans, some_nans, all_nans]:
    print("np.max/np.min:", np.max(array), np.min(array))
    print("np.nanmax/np.nanmin:", np.nanmax(array), np.nanmin(array))
```

Output

```bash
np.max/np.min: 3.0 1.0
np.nanmax/np.nanmin: 3.0 1.0
np.max/np.min: nan nan
np.nanmax/np.nanmin: 3.0 1.0
np.max/np.min: nan nan
np.nanmax/np.nanmin: nan nan
RuntimeWarning: All-NaN slice encountered
  print("np.nanmax/np.nanmin:", np.nanmax(array), np.nanmin(array))
```

### Motivation and Context

<!--
- Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here.
-->
1 parent fcca2ca commit df1fd96

File tree

2 files changed: +11 −11 lines changed

onnxruntime/python/tools/quantization/calibrate.py

+10 −10
```diff
@@ -496,14 +496,14 @@ def compute_data(self) -> TensorsData:
         pairs = []
         for i in range(0, len(added_output_names), 2):
             if self.moving_average:
-                min_value_array = np.mean(merged_added_output_dict[added_output_names[i]], axis=0)
-                max_value_array = np.mean(merged_added_output_dict[added_output_names[i + 1]], axis=0)
+                min_value_array = np.nanmean(merged_added_output_dict[added_output_names[i]], axis=0)
+                max_value_array = np.nanmean(merged_added_output_dict[added_output_names[i + 1]], axis=0)
             else:
-                min_value_array = np.min(merged_added_output_dict[added_output_names[i]], axis=0)
-                max_value_array = np.max(merged_added_output_dict[added_output_names[i + 1]], axis=0)
+                min_value_array = np.nanmin(merged_added_output_dict[added_output_names[i]], axis=0)
+                max_value_array = np.nanmax(merged_added_output_dict[added_output_names[i + 1]], axis=0)

             if self.symmetric:
-                max_absolute_value = np.max([np.abs(min_value_array), np.abs(max_value_array)], axis=0)
+                max_absolute_value = np.nanmax([np.abs(min_value_array), np.abs(max_value_array)], axis=0)
                 pairs.append((-max_absolute_value, max_absolute_value))
             else:
                 pairs.append((min_value_array, max_value_array))
```
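For context on the `moving_average` branch, a minimal sketch (illustrative values, not part of the commit): `np.nanmean` averages across calibration batches while skipping `nan` entries per element, whereas `np.mean` lets a single `nan` poison that element's average.

```python
import numpy as np

# Hypothetical per-batch min arrays, one row per calibration batch;
# a padded-token activation produced the nan in the second batch.
batch_mins = np.array([[0.1, -2.0, 0.3],
                       [np.nan, -1.5, 0.2]], dtype=np.float32)

print(np.mean(batch_mins, axis=0))     # [ nan  -1.75  0.25]
print(np.nanmean(batch_mins, axis=0))  # [ 0.1  -1.75  0.25]
```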
```diff
@@ -834,8 +834,8 @@ def collect_absolute_value(self, name_to_arr):
             data_arr_np = data_arr
             data_arr_np = data_arr_np.flatten()
             if data_arr_np.size > 0:
-                min_value = np.min(data_arr_np)
-                max_value = np.max(data_arr_np)
+                min_value = np.nanmin(data_arr_np)
+                max_value = np.nanmax(data_arr_np)
             else:
                 min_value = np.array(0, dtype=data_arr_np.dtype)
                 max_value = np.array(0, dtype=data_arr_np.dtype)
```
```diff
@@ -858,7 +858,7 @@ def collect_absolute_value(self, name_to_arr):
                 assert hasattr(old_max, "dtype"), f"old_min should be a numpy array but is {type(old_max)}"
                 old_hist = old_histogram[0]
                 old_hist_edges = old_histogram[1]
-                temp_amax = np.max(data_arr_np)
+                temp_amax = np.nanmax(data_arr_np)
                 if temp_amax > old_hist_edges[-1]:
                     # increase the number of bins
                     width = old_hist_edges[1] - old_hist_edges[0]
```
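The `temp_amax` change also fixes a silent failure mode; a small sketch with illustrative values (not part of the commit): with `np.max`, a single `nan` makes the bin-expansion comparison evaluate to False even when real data exceeds the current histogram range.

```python
import numpy as np

data_arr_np = np.array([np.nan, 5.0], dtype=np.float32)
old_hist_edges = np.array([0.0, 1.0, 2.0])  # hypothetical existing bin edges

print(np.max(data_arr_np) > old_hist_edges[-1])     # False: nan > 2.0 is False, bins never grow
print(np.nanmax(data_arr_np) > old_hist_edges[-1])  # True: 5.0 > 2.0, bins expand as intended
```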
```diff
@@ -882,8 +882,8 @@ def collect_value(self, name_to_arr):
             data_arr = data_arr.flatten()  # noqa: PLW2901

             if data_arr.size > 0:
-                min_value = np.min(data_arr)
-                max_value = np.max(data_arr)
+                min_value = np.nanmin(data_arr)
+                max_value = np.nanmax(data_arr)
             else:
                 min_value = np.array(0, dtype=data_arr.dtype)
                 max_value = np.array(0, dtype=data_arr.dtype)
```
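A sketch of the downstream effect in these histogram collectors (illustrative, not part of the commit): `np.nanmin`/`np.nanmax` yield a finite range despite `nan` samples, and `np.histogram` with an explicit finite range leaves the `nan` samples uncounted rather than failing.

```python
import numpy as np

data_arr = np.array([np.nan, 0.5, -1.0, 2.0], dtype=np.float32).flatten()
min_value = np.nanmin(data_arr)  # -1.0 (np.min would return nan)
max_value = np.nanmax(data_arr)  #  2.0 (np.max would return nan)

# Values outside the explicit range -- including nan -- are simply ignored.
hist, edges = np.histogram(data_arr, bins=4, range=(float(min_value), float(max_value)))
print(hist)  # [1 0 1 1]: the nan sample lands in no bin
```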

onnxruntime/python/tools/quantization/quant_utils.py

+1 −1
```diff
@@ -290,7 +290,7 @@ def compute_scale_zp(rmin, rmax, qmin, qmax, symmetric=False, min_real_range=None):
     dr = numpy.array(rmax - rmin, dtype=numpy.float64)
     dq = numpy.array(qmax, dtype=numpy.float64) - numpy.array(qmin, dtype=numpy.float64)
     scale = numpy.array(dr / dq)
-    assert scale >= 0, "scale isse"
+    assert scale >= 0, "scale issue"
     if scale < numpy.finfo(rmax.dtype).tiny:
         scale = numpy.array(1.0, dtype=rmax.dtype)
         zero_point = numpy.array(0, dtype=qmin.dtype)
```
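To make the failure referenced in the description concrete, a minimal sketch (illustrative values, not part of the commit) of how a `nan` range value propagates to `scale` and trips the assertion above; the variable names follow `compute_scale_zp`.

```python
import numpy as np

# Illustrative values: rmin is nan because np.min saw a nan activation.
rmin = np.array(np.nan, dtype=np.float32)
rmax = np.array(3.0, dtype=np.float32)
qmin = np.array(-128, dtype=np.int32)
qmax = np.array(127, dtype=np.int32)

dr = np.array(rmax - rmin, dtype=np.float64)  # nan - x -> nan
dq = np.array(qmax, dtype=np.float64) - np.array(qmin, dtype=np.float64)
scale = np.array(dr / dq)  # nan / 255.0 -> nan

try:
    # `nan >= 0` evaluates to False, so this assertion fails.
    assert scale >= 0, "scale issue"
except AssertionError as err:
    print("AssertionError:", err)
```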
