Skip to content

Commit ca6afce

Browse files
hyunn9973copybara-github
authored andcommitted
Add Cosine Similarity metric.
PiperOrigin-RevId: 860610896
1 parent 6506745 commit ca6afce

File tree

6 files changed

+100
-0
lines changed

6 files changed

+100
-0
lines changed

src/metrax/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
Average = base.Average
2727
AveragePrecisionAtK = ranking_metrics.AveragePrecisionAtK
2828
BLEU = nlp_metrics.BLEU
29+
CosineSimilarity = image_metrics.CosineSimilarity
2930
DCGAtK = ranking_metrics.DCGAtK
3031
Dice = image_metrics.Dice
3132
FBetaScore = classification_metrics.FBetaScore
@@ -57,6 +58,7 @@
5758
"Average",
5859
"AveragePrecisionAtK",
5960
"BLEU",
61+
"CosineSimilarity",
6062
"DCGAtK",
6163
"Dice",
6264
"FBetaScore",

src/metrax/image_metrics.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -666,3 +666,43 @@ def compute(self) -> jax.Array:
666666
"""Returns the final Dice coefficient."""
667667
epsilon = 1e-7
668668
return (2.0 * self.intersection) / (self.sum_pred + self.sum_true + epsilon)
669+
670+
671+
@flax.struct.dataclass
672+
class CosineSimilarity(base.Average):
673+
r"""Calculates the Cosine Similarity between two arrays.
674+
675+
The Cosine Similarity is defined as the dot product of the vectors divided
676+
by the product of their magnitudes (norms).
677+
678+
.. math::
679+
cos_{sim}(x,y) = \frac{x \cdot y}{||x|| * ||y||}
680+
"""
681+
682+
@classmethod
683+
def from_model_output(
684+
cls,
685+
predictions: jax.Array,
686+
targets: jax.Array,
687+
axis: int = -1,
688+
) -> 'CosineSimilarity':
689+
"""Creates a CosineSimilarity instance.
690+
691+
Args:
692+
predictions: A floating point array of the predictions. The shape should
693+
be (batch_size,).
694+
targets: A floating point array of the targets. The shape should be
695+
(batch_size,).
696+
axis: The axis to compute the norm over.
697+
698+
Returns:
699+
A `CosineSimilarity` instance.
700+
"""
701+
dot_product = (predictions * targets).sum(axis=axis)
702+
predictions_norm = jnp.linalg.norm(predictions, ord=2, axis=axis)
703+
targets_norm = jnp.linalg.norm(targets, ord=2, axis=axis)
704+
705+
cosine_similarity = dot_product / (predictions_norm * targets_norm)
706+
707+
return super().from_model_output(values=cosine_similarity)
708+

src/metrax/image_metrics_test.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,47 @@ def test_dice(self, y_true, y_pred):
553553

554554
np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5)
555555

556+
@parameterized.named_parameters(
557+
(
558+
'cosine_similarity_basic_f32',
559+
PREDS_1,
560+
TARGETS_1,
561+
),
562+
(
563+
'cosine_similarity_multichannel_norm',
564+
PREDS_2,
565+
TARGETS_2,
566+
),
567+
(
568+
'cosine_similarity_uint8_range_single_channel',
569+
PREDS_3,
570+
TARGETS_3,
571+
),
572+
(
573+
'cosine_similarity_identical_images',
574+
PREDS_4,
575+
TARGETS_4,
576+
),
577+
(
578+
'cosine_similarity_large_batch',
579+
PREDS_6,
580+
TARGETS_6,
581+
),
582+
)
583+
def test_cosine_similarity_against_keras(self, predictions, targets):
584+
"""Test that CosineSimilarity computes expected values."""
585+
predictions = jnp.array(predictions)
586+
targets = jnp.array(targets)
587+
keras_cosine_similarity_metric = keras.metrics.CosineSimilarity()
588+
keras_cosine_similarity_metric.update_state(predictions, targets)
589+
expected = keras_cosine_similarity_metric.result()
590+
591+
metric = metrax.CosineSimilarity.from_model_output(
592+
predictions=predictions, targets=targets
593+
)
594+
result = metric.compute()
595+
596+
np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5)
556597

557598
if __name__ == '__main__':
558599
absltest.main()

src/metrax/metrax_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,14 @@ class MetraxTest(parameterized.TestCase):
9292
'ks': KS,
9393
},
9494
),
95+
(
96+
'cosinesimilarity',
97+
metrax.CosineSimilarity,
98+
{
99+
'predictions': OUTPUT_LABELS,
100+
'targets': OUTPUT_PREDS,
101+
},
102+
),
95103
(
96104
'dcgAtK',
97105
metrax.DCGAtK,

src/metrax/nnx/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
Average = nnx_metrics.Average
2121
AveragePrecisionAtK = nnx_metrics.AveragePrecisionAtK
2222
BLEU = nnx_metrics.BLEU
23+
CosineSimilarity = nnx_metrics.CosineSimilarity
2324
DCGAtK = nnx_metrics.DCGAtK
2425
Dice = nnx_metrics.Dice
2526
FBetaScore = nnx_metrics.FBetaScore
@@ -50,6 +51,7 @@
5051
"Average",
5152
"AveragePrecisionAtK",
5253
"BLEU",
54+
"CosineSimilarity",
5355
"DCGAtK",
5456
"Dice",
5557
"FBetaScore",

src/metrax/nnx/nnx_metrics.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,13 @@ def __init__(self):
6060
super().__init__(metrax.BLEU)
6161

6262

63+
class CosineSimilarity(NnxWrapper):
64+
"""An NNX class for the Metrax metric CosineSimilarity."""
65+
66+
def __init__(self):
67+
super().__init__(metrax.CosineSimilarity)
68+
69+
6370
class DCGAtK(NnxWrapper):
6471
"""An NNX class for the Metrax metric DCGAtK."""
6572

0 commit comments

Comments
 (0)