You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Quantization tool: Use nanmin, nanmax, nanmean in calibrator (#23749)
### Description
- The calibrator uses `np.max/np.min` to get min/max values from
collected data. However, these functions return `nan` if any of the
array values is `nan` which subsequently leads invalid scale and failure
during quantization at
https://github.com/microsoft/onnxruntime/blob/93689c5995dcacbb99c3afa9ec477b305c71159f/onnxruntime/python/tools/quantization/quant_utils.py#L293.
- When quantizing models with `GroupQueryAttention`, the intermediate
activations corresponding to padded tokens can become nan. We can safely
ignore such values as they don't contribute to the final model output.
- Using `np.nanmax/np.nanmin` ensures that the calibrator can handle
`nan` values. If all values are nan, numpy raises a `RuntimeWarning:
All-NaN slice encountered` warning which can help debug the eventual
scale issue failure.
```python
import numpy as np
no_nans = np.array([1, 2, 3], dtype=np.float32)
some_nans = np.array([np.nan, 1, 2, 3, np.nan, np.nan], dtype=np.float32)
all_nans = np.array([np.nan, np.nan], dtype=np.float32)
for array in [no_nans, some_nans, all_nans]:
print("np.max/np.min:", np.max(array), np.min(array))
print("np.nanmax/np.nanmin:", np.nanmax(array), np.nanmin(array))
```
Output
```bash
np.max/np.min: 3.0 1.0
np.nanmax/np.nanmin: 3.0 1.0
np.max/np.min: nan nan
np.nanmax/np.nanmin: 3.0 1.0
np.max/np.min: nan nan
np.nanmax/np.nanmin: nan nan
RuntimeWarning: All-NaN slice encountered
print("np.nanmax/np.nanmin:", np.nanmax(array), np.nanmin(array))
```
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
0 commit comments