Skip to content

Commit 84c7cc6

Browse files
jshin1394 authored and copybara-github committed
[MetraxAudit] Audit fixes for FBetaScore and Dice metrics.
Classification Metrics:
- Fix NameError in FBetaScore.compute() (beta -> self.beta).
- Improve numerical stability in FBetaScore using divide_no_nan.

Image Metrics:
- Improve numerical stability in the Dice metric using divide_no_nan.
- Update tests to reflect the stability improvements.

PiperOrigin-RevId: 864705338
1 parent 8be27a9 commit 84c7cc6

File tree

4 files changed

+138
-146
lines changed

4 files changed

+138
-146
lines changed

src/metrax/classification_metrics.py

Lines changed: 111 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -581,145 +581,117 @@ def compute(self) -> jax.Array:
581581

582582
@flax.struct.dataclass
583583
class FBetaScore(clu_metrics.Metric):
584-
"""
585-
F-Beta score Metric class
586-
Computes the F-Beta score for the binary classification given 'predictions' and 'labels'.
587-
588-
Formula for F-Beta Score:
589-
b2 = beta ** 2
590-
f_beta_score = ((1 + b2) * (precision * recall)) / (precision * b2 + recall)
591-
592-
F-Beta turns into the F1 Score when beta = 1.0
593-
594-
Attributes:
595-
true_positives: The count of true positive instances from the given data,
596-
label, and threshold.
597-
false_positives: The count of false positive instances from the given data,
598-
label, and threshold.
599-
false_negatives: The count of false negative instances from the given data,
600-
label, and threshold.
601-
beta: The beta value used in the F-Score metric
602-
"""
584+
"""F-Beta score Metric class.
603585
604-
true_positives: jax.Array
605-
false_positives: jax.Array
606-
false_negatives: jax.Array
607-
beta: float = 1.0
608-
609-
# Reset the variables for the class
610-
@classmethod
611-
def empty(cls) -> 'FBetaScore':
612-
return cls(
613-
true_positives = jnp.array(0, jnp.float32),
614-
false_positives = jnp.array(0, jnp.float32),
615-
false_negatives = jnp.array(0, jnp.float32),
616-
beta=1.0,
617-
)
618-
619-
@classmethod
620-
def from_model_output(
621-
cls,
622-
predictions: jax.Array,
623-
labels: jax.Array,
624-
beta = beta,
625-
threshold = 0.5,) -> 'FBetaScore':
626-
"""Updates the metric.
627-
Note: When only predictions and labels are given, the score calculated
628-
is the F1 score if the FBetaScore beta value has not been previously modified.
629-
630-
Args:
631-
predictions: A floating point 1D vector whose values are in the range [0,
632-
1]. The shape should be (batch_size,).
633-
labels: True value. The value is expected to be 0 or 1. The shape should
634-
be (batch_size,).
635-
beta: beta value to use in the F-Score metric. A floating number.
636-
threshold: threshold value to use in the F-Score metric. A floating number.
637-
638-
Returns:
639-
The updated FBetaScore object.
640-
641-
Raises:
642-
ValueError: If type of `labels` is wrong or the shapes of `predictions`
643-
and `labels` are incompatible. If the beta or threshold are invalid values
644-
an error is raised as well.
645-
"""
646-
647-
# Ensure that beta is a floating number and a valid number
648-
if not isinstance(beta, float):
649-
raise ValueError('The "Beta" value must be a floating number.')
650-
if beta <= 0.0:
651-
raise ValueError('The "Beta" value must be larger than 0.0.')
652-
653-
# Ensure that threshold is a floating number and a valid number
654-
if not isinstance(threshold, float):
655-
raise ValueError('The "Threshold" value must be a floating number.')
656-
if threshold < 0.0 or threshold > 1.0:
657-
raise ValueError('The "Threshold" value must be between 0 and 1.')
658-
659-
# Modify predictions with the given threshold value
660-
predictions = jnp.where(predictions >= threshold, 1, 0)
661-
662-
# Assign the true_positive, false_positive, and false_negative their values
663-
"""
664-
We are calculating these values manually instead of using Metrax's
665-
precision and recall classes. This is because the Metrax versions end up
666-
outputting a single numerical answer when we need an array of numbers.
667-
"""
668-
true_positives = jnp.sum(predictions * labels, axis = 0)
669-
false_positives = jnp.sum(predictions * (1 - labels), axis = 0)
670-
false_negatives = jnp.sum((1- predictions) * labels, axis = 0)
671-
672-
return cls(true_positives = true_positives,
673-
false_positives = false_positives,
674-
false_negatives = false_negatives,
675-
beta = beta)
586+
Computes the F-Beta score for the binary classification given 'predictions'
587+
and 'labels'.
676588
589+
Formula for F-Beta Score:
590+
b2 = beta ** 2
591+
f_beta_score = ((1 + b2) * (precision * recall)) / (precision * b2 +
592+
recall)
593+
594+
F-Beta turns into the F1 Score when beta = 1.0
595+
596+
Attributes:
597+
true_positives: The count of true positive instances from the given data,
598+
label, and threshold.
599+
false_positives: The count of false positive instances from the given
600+
data, label, and threshold.
601+
false_negatives: The count of false negative instances from the given
602+
data, label, and threshold.
603+
beta: The beta value used in the F-Score metric
604+
"""
605+
606+
true_positives: jax.Array
607+
false_positives: jax.Array
608+
false_negatives: jax.Array
609+
beta: float = 1.0
610+
611+
# Reset the variables for the class
612+
@classmethod
613+
def empty(cls) -> 'FBetaScore':
614+
return cls(
615+
true_positives=jnp.array(0, jnp.float32),
616+
false_positives=jnp.array(0, jnp.float32),
617+
false_negatives=jnp.array(0, jnp.float32),
618+
beta=1.0,
619+
)
620+
621+
@classmethod
622+
def from_model_output(
623+
cls,
624+
predictions: jax.Array,
625+
labels: jax.Array,
626+
beta: float = 1.0,
627+
threshold: float = 0.5,
628+
) -> 'FBetaScore':
629+
"""Updates the metric.
630+
631+
Note: When only predictions and labels are given, the score calculated is
632+
the F1 score if the FBetaScore beta value has not been previously modified.
633+
634+
Args:
635+
predictions: A floating point 1D vector whose values are in the range
636+
[0, 1]. The shape should be (batch_size,).
637+
labels: True value. The value is expected to be 0 or 1. The shape should
638+
be (batch_size,).
639+
beta: beta value to use in the F-Score metric. A floating number.
640+
threshold: threshold value to use in the F-Score metric. A floating
641+
number.
642+
643+
Returns:
644+
The updated FBetaScore object.
645+
646+
Raises:
647+
ValueError: If type of `labels` is wrong or the shapes of `predictions`
648+
and `labels` are incompatible. If the beta or threshold are invalid
649+
values
650+
an error is raised as well.
677651
"""
678-
This function is currently unused as the 'from_model_output' function can handle the whole
679-
dataset without needing to split and merge them. I'm leaving this here for now incase we want to
680-
repurpose this or need to change something that requires this function's use again. This function would need
681-
to be reworked for it to work with the current implementation of this class.
682-
"""
683-
# # Merge datasets together
684-
# def merge(self, other: 'FBetaScore') -> 'FBetaScore':
685-
#
686-
# # Check if the incoming beta is the same value as the current beta
687-
# if other.beta == self.beta:
688-
# return type(self)(
689-
# true_positives = self.true_positives + other.true_positives,
690-
# false_positives = self.false_positives + other.false_positives,
691-
# false_negatives = self.false_negatives + other.false_negatives,
692-
# beta=self.beta,
693-
# )
694-
# else:
695-
# raise ValueError('The "Beta" values between the two are not equal.')
696-
697-
# Compute the F-Beta score metric
698-
def compute(self) -> jax.Array:
699-
700-
# Epsilon fuz factor required to match with the keras version
701-
epsilon = 1e-7
702-
703-
# Manually calculate precision and recall
704-
"""
705-
This is done in this manner since the metrax variants of precision
706-
and recall output only single numbers. To match the keras value
707-
we need an array of numbers to work with.
708-
"""
709-
precision = jnp.divide(
710-
self.true_positives,
711-
self.true_positives + self.false_positives + epsilon,
712-
)
713-
recall = jnp.divide(
714-
self.true_positives,
715-
self.true_positives + self.false_negatives + epsilon,
716-
)
717-
718-
# Compute the numerator and denominator of the F-Score formula
719-
b2 = self.beta ** 2
720-
numerator = (1 + b2) * (precision * recall)
721-
denominator = (b2 * precision) + recall
722-
723-
return base.divide_no_nan(
724-
numerator, denominator + epsilon,
725-
)
652+
653+
# Ensure that beta is a floating number and a valid number
654+
if not isinstance(beta, float):
655+
raise ValueError('The "Beta" value must be a floating number.')
656+
if beta <= 0.0:
657+
raise ValueError('The "Beta" value must be larger than 0.0.')
658+
659+
# Ensure that threshold is a floating number and a valid number
660+
if not isinstance(threshold, float):
661+
raise ValueError('The "Threshold" value must be a floating number.')
662+
if threshold < 0.0 or threshold > 1.0:
663+
raise ValueError('The "Threshold" value must be between 0 and 1.')
664+
665+
# Modify predictions with the given threshold value
666+
predictions = jnp.where(predictions >= threshold, 1, 0)
667+
668+
# Assign the true_positive, false_positive, and false_negative their values
669+
true_positives = jnp.sum(predictions * labels, axis=0)
670+
false_positives = jnp.sum(predictions * (1 - labels), axis=0)
671+
false_negatives = jnp.sum((1 - predictions) * labels, axis=0)
672+
673+
return cls(
674+
true_positives=true_positives,
675+
false_positives=false_positives,
676+
false_negatives=false_negatives,
677+
beta=beta,
678+
)
679+
680+
# Compute the F-Beta score metric
681+
def compute(self) -> jax.Array:
682+
683+
# Compute the numerator and denominator of the F-Score formula
684+
precision = base.divide_no_nan(
685+
self.true_positives,
686+
self.true_positives + self.false_positives,
687+
)
688+
recall = base.divide_no_nan(
689+
self.true_positives,
690+
self.true_positives + self.false_negatives,
691+
)
692+
693+
b2 = self.beta**2
694+
numerator = (1 + b2) * (precision * recall)
695+
denominator = (b2 * precision) + recall
696+
697+
return base.divide_no_nan(numerator, denominator)

src/metrax/classification_metrics_test.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def test_aucroc(self, inputs, dtype):
273273
atol=1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-7,
274274
)
275275

276-
# Testing function for F-Beta classification
276+
# Testing function for F-Beta classification
277277
# name, output true, output prediction, threshold, beta
278278
@parameterized.named_parameters(
279279
('basic_f16_beta_1.0', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.5, 1.0),
@@ -307,6 +307,26 @@ def test_fbetascore(self, y_true, y_pred, threshold, beta):
307307
atol=atol,
308308
)
309309

310+
def test_fbeta_default_init(self):
311+
"""Test usage of FBetaScore with default beta (should be 1.0)."""
312+
y_true = jnp.array([1, 0, 1, 0], dtype=jnp.float32)
313+
y_pred = jnp.array([1, 0, 0, 1], dtype=jnp.float32)
314+
315+
# Init without beta argument
316+
metric = metrax.FBetaScore.from_model_output(
317+
predictions=y_pred,
318+
labels=y_true,
319+
)
320+
321+
# Should default to beta=1.0 (F1 Score)
322+
self.assertEqual(metric.beta, 1.0)
323+
324+
# Calculate expected F1
325+
# TP=1 (idx 0), FP=1 (idx 3), FN=1 (idx 2)
326+
# Precision = 1/2, Recall = 1/2
327+
# F1 = 2 * (0.5 * 0.5) / (0.5 + 0.5) = 0.5
328+
self.assertAlmostEqual(metric.compute(), 0.5)
329+
310330

311331
if __name__ == '__main__':
312332
absltest.main()

src/metrax/image_metrics.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -663,9 +663,9 @@ def merge(self, other: 'Dice') -> 'Dice':
663663
)
664664

665665
def compute(self) -> jax.Array:
666-
"""Returns the final Dice coefficient."""
667-
epsilon = 1e-7
668-
return (2.0 * self.intersection) / (self.sum_pred + self.sum_true + epsilon)
666+
return base.divide_no_nan(
667+
2.0 * self.intersection, self.sum_pred + self.sum_true
668+
)
669669

670670

671671
@flax.struct.dataclass
@@ -705,4 +705,3 @@ def from_model_output(
705705
cosine_similarity = dot_product / (predictions_norm * targets_norm)
706706

707707
return super().from_model_output(values=cosine_similarity)
708-

src/metrax/image_metrics_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -541,11 +541,12 @@ def test_dice(self, y_true, y_pred):
541541
y_pred = jnp.asarray(y_pred, jnp.float32)
542542

543543
# Manually compute expected Dice
544-
eps = 1e-7
545544
intersection = jnp.sum(y_true * y_pred)
546545
sum_pred = jnp.sum(y_pred)
547546
sum_true = jnp.sum(y_true)
548-
expected = (2.0 * intersection) / (sum_pred + sum_true + eps)
547+
expected = metrax.base.divide_no_nan(
548+
2.0 * intersection, sum_pred + sum_true
549+
)
549550

550551
# Compute using the metric class
551552
metric = metrax.Dice.from_model_output(predictions=y_pred, labels=y_true)

0 commit comments

Comments (0)