diff --git a/src/nncf/common/tensor_statistics/collectors.py b/src/nncf/common/tensor_statistics/collectors.py
index 2ee9cd362b9..a3666b6718e 100644
--- a/src/nncf/common/tensor_statistics/collectors.py
+++ b/src/nncf/common/tensor_statistics/collectors.py
@@ -928,9 +928,13 @@ def __init__(self, num_samples: int | None = None):
 
     def _register_reduced_input_impl(self, x: Tensor) -> None:
         trace = fns.sum(fns.multiply(x, x))
-        # NOTE: average trace?? divide by number of diagonal elements
-        # TODO(dlyakhov): revise this formula as possibly it is with an error; adopted from previous HAWQ implementation
-        self._container = (self._container + trace) / x.size
+        # Accumulate each sample's trace normalized by its element count (x.size) so that larger tensors do not
+        # dominate the sensitivity scores; the average over samples is taken in _aggregate_impl.
+        self._container += trace / x.size
+
+    def reset(self) -> None:
+        self._collected_samples = 0
+        self._container = Tensor(0.0)
 
     def _aggregate_impl(self) -> Tensor:
         return self._container * 2 / self._collected_samples
diff --git a/tests/common/test_reducers_and_aggregators.py b/tests/common/test_reducers_and_aggregators.py
index 53b4fa98a6f..b0912751393 100644
--- a/tests/common/test_reducers_and_aggregators.py
+++ b/tests/common/test_reducers_and_aggregators.py
@@ -714,8 +714,8 @@ def test_aggregators_hash(self, aggregator_cls):
 
     HAWQ_AGGREGATOR_REFERENCE_VALUES = [
         ([np.arange(10)], 57.0),
-        ([np.arange(12).reshape((2, 6)), np.arange(24).reshape((4, 6))], 181.92361111111111),
-        ([np.arange(8 * i).reshape((1, 8, i)) for i in range(1, 5)], 165.61627197265625),
+        ([np.arange(12).reshape((2, 6)), np.arange(24).reshape((4, 6))], 222.33333333333331),
+        ([np.arange(8 * i).reshape((1, 8, i)) for i in range(1, 5)], 300.3333333333333),
     ]
 
     @pytest.mark.parametrize("inputs,reference_output", HAWQ_AGGREGATOR_REFERENCE_VALUES)