@@ -553,6 +553,47 @@ def test_dice(self, y_true, y_pred):
553553
554554 np .testing .assert_allclose (result , expected , rtol = 1e-5 , atol = 1e-5 )
555555
556+ @parameterized .named_parameters (
557+ (
558+ 'cosine_similarity_basic_f32' ,
559+ PREDS_1 ,
560+ TARGETS_1 ,
561+ ),
562+ (
563+ 'cosine_similarity_multichannel_norm' ,
564+ PREDS_2 ,
565+ TARGETS_2 ,
566+ ),
567+ (
568+ 'cosine_similarity_uint8_range_single_channel' ,
569+ PREDS_3 ,
570+ TARGETS_3 ,
571+ ),
572+ (
573+ 'cosine_similarity_identical_images' ,
574+ PREDS_4 ,
575+ TARGETS_4 ,
576+ ),
577+ (
578+ 'cosine_similarity_large_batch' ,
579+ PREDS_6 ,
580+ TARGETS_6 ,
581+ ),
582+ )
583+ def test_cosine_similarity_against_keras (self , predictions , targets ):
584+ """Test that CosineSimilarity computes expected values."""
585+ predictions = jnp .array (predictions )
586+ targets = jnp .array (targets )
587+ keras_cosine_similarity_metric = keras .metrics .CosineSimilarity ()
588+ keras_cosine_similarity_metric .update_state (predictions , targets )
589+ expected = keras_cosine_similarity_metric .result ()
590+
591+ metric = metrax .CosineSimilarity .from_model_output (
592+ predictions = predictions , targets = targets
593+ )
594+ result = metric .compute ()
595+
596+ np .testing .assert_allclose (result , expected , rtol = 1e-5 , atol = 1e-5 )
556597
557598if __name__ == '__main__' :
558599 absltest .main ()
0 commit comments