diff --git a/Python/ml_metrics/average_precision.py b/Python/ml_metrics/average_precision.py index e18297d..c735735 100644 --- a/Python/ml_metrics/average_precision.py +++ b/Python/ml_metrics/average_precision.py @@ -22,6 +22,12 @@ def apk(actual, predicted, k=10): The average precision at k over the input lists """ + if actual is None or len(actual) == 0: + return 0.0 + + if predicted is None or len(predicted) == 0: + return 0.0 + if len(predicted)>k: predicted = predicted[:k] @@ -33,9 +39,6 @@ def apk(actual, predicted, k=10): num_hits += 1.0 score += num_hits / (i+1.0) - if not actual: - return 0.0 - return score / min(len(actual), k) def mapk(actual, predicted, k=10): @@ -48,7 +51,7 @@ def mapk(actual, predicted, k=10): Parameters ---------- actual : list - A list of lists of elements that are to be predicted + A list of lists of elements that are to be predicted (order doesn't matter in the lists) predicted : list A list of lists of predicted elements diff --git a/Python/ml_metrics/test/test_average_precision.py b/Python/ml_metrics/test/test_average_precision.py index d1cf761..2af7b57 100644 --- a/Python/ml_metrics/test/test_average_precision.py +++ b/Python/ml_metrics/test/test_average_precision.py @@ -9,9 +9,20 @@ class TestAveragePrecision(unittest.TestCase): def test_apk(self): self.assertAlmostEqual(metrics.apk(range(1,6),[6,4,7,1,2], 2), 0.25) self.assertAlmostEqual(metrics.apk(range(1,6),[1,1,1,1,1], 5), 0.2) - predicted = range(1,21) + predicted = list(range(1,21)) predicted.extend(range(200,600)) self.assertAlmostEqual(metrics.apk(range(1,100),predicted, 20), 1.0) + # numpy array test + self.assertAlmostEqual(metrics.apk(np.asarray(range(1,100)),predicted, 20), 1.0) + self.assertAlmostEqual(metrics.apk(range(1,100), np.asarray(predicted), 20), 1.0) + + + def test_apk_empties(self): + self.assertAlmostEqual(metrics.apk([], [1, 3], 20), 0.0) + self.assertAlmostEqual(metrics.apk(None, [1, 2], 20), 0.0) + self.assertAlmostEqual(metrics.apk([1, 3], [], 20), 0.0) + self.assertAlmostEqual(metrics.apk([1, 2], None, 20), 0.0) + def test_mapk(self): self.assertAlmostEqual(metrics.mapk([range(1,5)],[range(1,5)],3), 1.0)