Skip to content

Commit 358d898

Browse files
jshin1394 authored and
copybara-github committed
[MetraxAudit] Audit fixes for FBetaScore and Dice metrics.
Classification Metrics:
- Fix NameError in FBetaScore.compute() (beta -> self.beta).
- Improve numerical stability in FBetaScore using divide_no_nan.

Image Metrics:
- Improve numerical stability in Dice metric using divide_no_nan.
- Update tests to reflect stability improvements.

PiperOrigin-RevId: 864705338
1 parent 8be27a9 commit 358d898

File tree

4 files changed

+205
-167
lines changed

4 files changed

+205
-167
lines changed

src/metrax/classification_metrics.py

Lines changed: 116 additions & 141 deletions
Original file line number | Diff line number | Diff line change
@@ -454,8 +454,10 @@ def compute(self) -> jax.Array:
454454

455455
@flax.struct.dataclass
456456
class AUCROC(clu_metrics.Metric):
457-
r"""Computes area under the receiver operation characteristic curve for binary classification given `predictions`
458-
and `labels`.
457+
r"""Computes Area Under the ROC curve (AUC-ROC).
458+
459+
Computes area under the receiver operation characteristic curve for binary
460+
classification given `predictions` and `labels`.
459461
460462
The ROC curve shows the tradeoff between the true positive rate (TPR) and
461463
false positive
@@ -579,147 +581,120 @@ def compute(self) -> jax.Array:
579581
# Threshold goes from 0 to 1, so trapezoid is negative.
580582
return jnp.trapezoid(tp_rate, fp_rate) * -1
581583

584+
582585
@flax.struct.dataclass
583586
class FBetaScore(clu_metrics.Metric):
584-
"""
585-
F-Beta score Metric class
586-
Computes the F-Beta score for the binary classification given 'predictions' and 'labels'.
587-
588-
Formula for F-Beta Score:
589-
b2 = beta ** 2
590-
f_beta_score = ((1 + b2) * (precision * recall)) / (precision * b2 + recall)
591-
592-
F-Beta turns into the F1 Score when beta = 1.0
593-
594-
Attributes:
595-
true_positives: The count of true positive instances from the given data,
596-
label, and threshold.
597-
false_positives: The count of false positive instances from the given data,
598-
label, and threshold.
599-
false_negatives: The count of false negative instances from the given data,
600-
label, and threshold.
601-
beta: The beta value used in the F-Score metric
602-
"""
587+
"""F-Beta score Metric class.
603588
604-
true_positives: jax.Array
605-
false_positives: jax.Array
606-
false_negatives: jax.Array
607-
beta: float = 1.0
608-
609-
# Reset the variables for the class
610-
@classmethod
611-
def empty(cls) -> 'FBetaScore':
612-
return cls(
613-
true_positives = jnp.array(0, jnp.float32),
614-
false_positives = jnp.array(0, jnp.float32),
615-
false_negatives = jnp.array(0, jnp.float32),
616-
beta=1.0,
617-
)
618-
619-
@classmethod
620-
def from_model_output(
621-
cls,
622-
predictions: jax.Array,
623-
labels: jax.Array,
624-
beta = beta,
625-
threshold = 0.5,) -> 'FBetaScore':
626-
"""Updates the metric.
627-
Note: When only predictions and labels are given, the score calculated
628-
is the F1 score if the FBetaScore beta value has not been previously modified.
629-
630-
Args:
631-
predictions: A floating point 1D vector whose values are in the range [0,
632-
1]. The shape should be (batch_size,).
633-
labels: True value. The value is expected to be 0 or 1. The shape should
634-
be (batch_size,).
635-
beta: beta value to use in the F-Score metric. A floating number.
636-
threshold: threshold value to use in the F-Score metric. A floating number.
637-
638-
Returns:
639-
The updated FBetaScore object.
640-
641-
Raises:
642-
ValueError: If type of `labels` is wrong or the shapes of `predictions`
643-
and `labels` are incompatible. If the beta or threshold are invalid values
644-
an error is raised as well.
645-
"""
646-
647-
# Ensure that beta is a floating number and a valid number
648-
if not isinstance(beta, float):
649-
raise ValueError('The "Beta" value must be a floating number.')
650-
if beta <= 0.0:
651-
raise ValueError('The "Beta" value must be larger than 0.0.')
652-
653-
# Ensure that threshold is a floating number and a valid number
654-
if not isinstance(threshold, float):
655-
raise ValueError('The "Threshold" value must be a floating number.')
656-
if threshold < 0.0 or threshold > 1.0:
657-
raise ValueError('The "Threshold" value must be between 0 and 1.')
658-
659-
# Modify predictions with the given threshold value
660-
predictions = jnp.where(predictions >= threshold, 1, 0)
661-
662-
# Assign the true_positive, false_positive, and false_negative their values
663-
"""
664-
We are calculating these values manually instead of using Metrax's
665-
precision and recall classes. This is because the Metrax versions end up
666-
outputting a single numerical answer when we need an array of numbers.
667-
"""
668-
true_positives = jnp.sum(predictions * labels, axis = 0)
669-
false_positives = jnp.sum(predictions * (1 - labels), axis = 0)
670-
false_negatives = jnp.sum((1- predictions) * labels, axis = 0)
671-
672-
return cls(true_positives = true_positives,
673-
false_positives = false_positives,
674-
false_negatives = false_negatives,
675-
beta = beta)
589+
Computes the F-Beta score for the binary classification given 'predictions'
590+
and 'labels'.
676591
592+
Formula for F-Beta Score:
593+
b2 = beta ** 2
594+
f_beta_score = ((1 + b2) * (precision * recall)) / (precision * b2 +
595+
recall)
596+
597+
F-Beta turns into the F1 Score when beta = 1.0
598+
599+
Attributes:
600+
true_positives: The count of true positive instances from the given data,
601+
label, and threshold.
602+
false_positives: The count of false positive instances from the given data,
603+
label, and threshold.
604+
false_negatives: The count of false negative instances from the given data,
605+
label, and threshold.
606+
beta: The beta value used in the F-Score metric.
607+
"""
608+
609+
true_positives: jax.Array
610+
false_positives: jax.Array
611+
false_negatives: jax.Array
612+
beta: float = 1.0
613+
614+
# Reset the variables for the class
615+
@classmethod
616+
def empty(cls) -> 'FBetaScore':
617+
return cls(
618+
true_positives=jnp.array(0, jnp.float32),
619+
false_positives=jnp.array(0, jnp.float32),
620+
false_negatives=jnp.array(0, jnp.float32),
621+
beta=1.0,
622+
)
623+
624+
@classmethod
625+
def from_model_output(
626+
cls,
627+
predictions: jax.Array,
628+
labels: jax.Array,
629+
beta: float = 1.0,
630+
threshold: float = 0.5,
631+
) -> 'FBetaScore':
632+
"""Updates the metric.
633+
634+
Note: When only predictions and labels are given, the score calculated is
635+
the F1 score if the FBetaScore beta value has not been previously modified.
636+
637+
Args:
638+
predictions: A floating point 1D vector whose values are in the range
639+
[0, 1]. The shape should be (batch_size,).
640+
labels: True value. The value is expected to be 0 or 1. The shape should
641+
be (batch_size,).
642+
beta: beta value to use in the F-Score metric. A floating number.
643+
threshold: threshold value to use in the F-Score metric. A floating
644+
number.
645+
646+
Returns:
647+
The updated FBetaScore object.
648+
649+
Raises:
650+
ValueError: If type of `labels` is wrong or the shapes of `predictions`
651+
and `labels` are incompatible. If the beta or threshold are invalid
652+
values
653+
an error is raised as well.
677654
"""
678-
This function is currently unused as the 'from_model_output' function can handle the whole
679-
dataset without needing to split and merge them. I'm leaving this here for now incase we want to
680-
repurpose this or need to change something that requires this function's use again. This function would need
681-
to be reworked for it to work with the current implementation of this class.
682-
"""
683-
# # Merge datasets together
684-
# def merge(self, other: 'FBetaScore') -> 'FBetaScore':
685-
#
686-
# # Check if the incoming beta is the same value as the current beta
687-
# if other.beta == self.beta:
688-
# return type(self)(
689-
# true_positives = self.true_positives + other.true_positives,
690-
# false_positives = self.false_positives + other.false_positives,
691-
# false_negatives = self.false_negatives + other.false_negatives,
692-
# beta=self.beta,
693-
# )
694-
# else:
695-
# raise ValueError('The "Beta" values between the two are not equal.')
696-
697-
# Compute the F-Beta score metric
698-
def compute(self) -> jax.Array:
699-
700-
# Epsilon fuz factor required to match with the keras version
701-
epsilon = 1e-7
702-
703-
# Manually calculate precision and recall
704-
"""
705-
This is done in this manner since the metrax variants of precision
706-
and recall output only single numbers. To match the keras value
707-
we need an array of numbers to work with.
708-
"""
709-
precision = jnp.divide(
710-
self.true_positives,
711-
self.true_positives + self.false_positives + epsilon,
712-
)
713-
recall = jnp.divide(
714-
self.true_positives,
715-
self.true_positives + self.false_negatives + epsilon,
716-
)
717-
718-
# Compute the numerator and denominator of the F-Score formula
719-
b2 = self.beta ** 2
720-
numerator = (1 + b2) * (precision * recall)
721-
denominator = (b2 * precision) + recall
722-
723-
return base.divide_no_nan(
724-
numerator, denominator + epsilon,
725-
)
655+
656+
# Ensure that beta is a floating number and a valid number
657+
if not isinstance(beta, float):
658+
raise ValueError('The "Beta" value must be a floating number.')
659+
if beta <= 0.0:
660+
raise ValueError('The "Beta" value must be larger than 0.0.')
661+
662+
# Ensure that threshold is a floating number and a valid number
663+
if not isinstance(threshold, float):
664+
raise ValueError('The "Threshold" value must be a floating number.')
665+
if threshold < 0.0 or threshold > 1.0:
666+
raise ValueError('The "Threshold" value must be between 0 and 1.')
667+
668+
# Modify predictions with the given threshold value
669+
predictions = jnp.where(predictions >= threshold, 1, 0)
670+
671+
# Assign the true_positive, false_positive, and false_negative their values
672+
true_positives = jnp.sum(predictions * labels, axis=0)
673+
false_positives = jnp.sum(predictions * (1 - labels), axis=0)
674+
false_negatives = jnp.sum((1 - predictions) * labels, axis=0)
675+
676+
return cls(
677+
true_positives=true_positives,
678+
false_positives=false_positives,
679+
false_negatives=false_negatives,
680+
beta=beta,
681+
)
682+
683+
# Compute the F-Beta score metric
684+
def compute(self) -> jax.Array:
685+
686+
# Compute the numerator and denominator of the F-Score formula
687+
precision = base.divide_no_nan(
688+
self.true_positives,
689+
self.true_positives + self.false_positives,
690+
)
691+
recall = base.divide_no_nan(
692+
self.true_positives,
693+
self.true_positives + self.false_negatives,
694+
)
695+
696+
b2 = self.beta**2
697+
numerator = (1 + b2) * (precision * recall)
698+
denominator = (b2 * precision) + recall
699+
700+
return base.divide_no_nan(numerator, denominator)

src/metrax/classification_metrics_test.py

Lines changed: 67 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -15,9 +15,10 @@
1515
"""Tests for metrax classification metrics."""
1616

1717
import os
18-
os.environ['KERAS_BACKEND'] = 'jax'
1918

20-
from absl.testing import absltest
19+
os.environ['KERAS_BACKEND'] = 'jax' # pylint: disable=g-import-not-at-top
20+
21+
from absl.testing import absltest # pylint: disable=g-import-not-at-top
2122
from absl.testing import parameterized
2223
import jax.numpy as jnp
2324
import keras
@@ -47,6 +48,7 @@
4748
(BATCHES, 1),
4849
).astype(np.float32)
4950

51+
5052
class ClassificationMetricsTest(parameterized.TestCase):
5153

5254
def test_precision_empty(self):
@@ -273,17 +275,53 @@ def test_aucroc(self, inputs, dtype):
273275
atol=1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-7,
274276
)
275277

276-
# Testing function for F-Beta classification
278+
# Testing function for F-Beta classification
277279
# name, output true, output prediction, threshold, beta
278280
@parameterized.named_parameters(
279281
('basic_f16_beta_1.0', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.5, 1.0),
280282
('basic_f32_beta_1.0', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.5, 1.0),
281-
('low_threshold_f32_beta_1.0', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.1, 1.0),
282-
('high_threshold_bf16_beta_1.0', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.7, 1.0),
283-
('batch_size_one_beta_1.0', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, 0.5, 1.0),
284-
('high_threshold_f16_beta_2.0', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.7, 2.0),
285-
('high_threshold_f32_beta_2.0', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.7, 2.0),
286-
('low_threshold_bf16_beta_2.0', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.1, 2.0),
283+
(
284+
'low_threshold_f32_beta_1.0',
285+
OUTPUT_LABELS,
286+
OUTPUT_PREDS_F32,
287+
0.1,
288+
1.0,
289+
),
290+
(
291+
'high_threshold_bf16_beta_1.0',
292+
OUTPUT_LABELS,
293+
OUTPUT_PREDS_BF16,
294+
0.7,
295+
1.0,
296+
),
297+
(
298+
'batch_size_one_beta_1.0',
299+
OUTPUT_LABELS_BS1,
300+
OUTPUT_PREDS_BS1,
301+
0.5,
302+
1.0,
303+
),
304+
(
305+
'high_threshold_f16_beta_2.0',
306+
OUTPUT_LABELS,
307+
OUTPUT_PREDS_F16,
308+
0.7,
309+
2.0,
310+
),
311+
(
312+
'high_threshold_f32_beta_2.0',
313+
OUTPUT_LABELS,
314+
OUTPUT_PREDS_F32,
315+
0.7,
316+
2.0,
317+
),
318+
(
319+
'low_threshold_bf16_beta_2.0',
320+
OUTPUT_LABELS,
321+
OUTPUT_PREDS_BF16,
322+
0.1,
323+
2.0,
324+
),
287325
('low_threshold_f16_beta_3.0', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.1, 3.0),
288326
('basic_bf16_beta_3.0', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.5, 3.0),
289327
)
@@ -307,6 +345,26 @@ def test_fbetascore(self, y_true, y_pred, threshold, beta):
307345
atol=atol,
308346
)
309347

348+
def test_fbeta_default_init(self):
349+
"""Test usage of FBetaScore with default beta (should be 1.0)."""
350+
y_true = jnp.array([1, 0, 1, 0], dtype=jnp.float32)
351+
y_pred = jnp.array([1, 0, 0, 1], dtype=jnp.float32)
352+
353+
# Init without beta argument
354+
metric = metrax.FBetaScore.from_model_output(
355+
predictions=y_pred,
356+
labels=y_true,
357+
)
358+
359+
# Should default to beta=1.0 (F1 Score)
360+
self.assertEqual(metric.beta, 1.0)
361+
362+
# Calculate expected F1
363+
# TP=1 (idx 0), FP=1 (idx 3), FN=1 (idx 2)
364+
# Precision = 1/2, Recall = 1/2
365+
# F1 = 2 * (0.5 * 0.5) / (0.5 + 0.5) = 0.5
366+
self.assertAlmostEqual(metric.compute(), 0.5)
367+
310368

311369
if __name__ == '__main__':
312370
absltest.main()

0 commit comments

Comments
 (0)