257 changes: 116 additions & 141 deletions src/metrax/classification_metrics.py
@@ -454,8 +454,10 @@ def compute(self) -> jax.Array:

@flax.struct.dataclass
class AUCROC(clu_metrics.Metric):
r"""Computes area under the receiver operation characteristic curve for binary classification given `predictions`
and `labels`.
r"""Computes Area Under the ROC curve (AUC-ROC).

Computes the area under the receiver operating characteristic curve for binary
classification given `predictions` and `labels`.

The ROC curve shows the tradeoff between the true positive rate (TPR) and
false positive
Expand Down Expand Up @@ -579,147 +581,120 @@ def compute(self) -> jax.Array:
# As the threshold sweeps from 0 to 1, fp_rate decreases, so the trapezoid
# integral is negative; flip the sign.
return jnp.trapezoid(tp_rate, fp_rate) * -1


@flax.struct.dataclass
class FBetaScore(clu_metrics.Metric):
"""
F-Beta score Metric class
Computes the F-Beta score for the binary classification given 'predictions' and 'labels'.

Formula for F-Beta Score:
b2 = beta ** 2
f_beta_score = ((1 + b2) * (precision * recall)) / (precision * b2 + recall)

F-Beta turns into the F1 Score when beta = 1.0

Attributes:
true_positives: The count of true positive instances from the given data,
label, and threshold.
false_positives: The count of false positive instances from the given data,
label, and threshold.
false_negatives: The count of false negative instances from the given data,
label, and threshold.
beta: The beta value used in the F-Score metric
"""
"""F-Beta score Metric class.

Computes the F-Beta score for binary classification given `predictions`
and `labels`.

Formula for the F-Beta score:
b2 = beta ** 2
f_beta_score = ((1 + b2) * precision * recall) / (b2 * precision + recall)

The F-Beta score reduces to the F1 score when beta = 1.0.

Attributes:
true_positives: The count of true positive instances from the given data,
labels, and threshold.
false_positives: The count of false positive instances from the given
data, labels, and threshold.
false_negatives: The count of false negative instances from the given
data, labels, and threshold.
beta: The beta value used in the F-Beta score.
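
Example:
A minimal usage sketch, mirroring how the tests call this class; with
the default beta of 1.0 the result is the F1 score:

import jax.numpy as jnp
import metrax

labels = jnp.array([1, 1, 0, 1], dtype=jnp.float32)
preds = jnp.array([0.9, 0.4, 0.2, 0.8], dtype=jnp.float32)
metric = metrax.FBetaScore.from_model_output(
predictions=preds, labels=labels, beta=1.0, threshold=0.5
)
f1 = metric.compute()
# Thresholded preds are [1, 0, 0, 1]: TP=2, FP=0, FN=1, so
# precision=1.0, recall=2/3, and F1 = 2 * (1.0 * 2/3) / (1.0 + 2/3) = 0.8.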
"""

true_positives: jax.Array
false_positives: jax.Array
false_negatives: jax.Array
beta: float = 1.0

# Create an empty metric with all counts zeroed
@classmethod
def empty(cls) -> 'FBetaScore':
return cls(
true_positives=jnp.array(0, jnp.float32),
false_positives=jnp.array(0, jnp.float32),
false_negatives=jnp.array(0, jnp.float32),
beta=1.0,
)

@classmethod
def from_model_output(
cls,
predictions: jax.Array,
labels: jax.Array,
beta: float = 1.0,
threshold: float = 0.5,
) -> 'FBetaScore':
"""Updates the metric.

Note: When only `predictions` and `labels` are given, `beta` defaults to
1.0, so the computed score is the F1 score.

Args:
predictions: A floating point 1D vector whose values are in the range
[0, 1]. The shape should be (batch_size,).
labels: Ground-truth values. Each value is expected to be 0 or 1. The
shape should be (batch_size,).
beta: Beta value to use in the F-Beta score. Must be a positive float.
threshold: Threshold used to binarize `predictions`. Must be a float in
[0, 1].

Returns:
The updated FBetaScore object.

Raises:
ValueError: If type of `labels` is wrong or the shapes of `predictions`
and `labels` are incompatible. If the beta or threshold are invalid
values
an error is raised as well.
"""

# Ensure that beta is a float and a valid value
if not isinstance(beta, float):
raise ValueError('The "Beta" value must be a float.')
if beta <= 0.0:
raise ValueError('The "Beta" value must be larger than 0.0.')

# Ensure that threshold is a float and a valid value
if not isinstance(threshold, float):
raise ValueError('The "Threshold" value must be a float.')
if threshold < 0.0 or threshold > 1.0:
raise ValueError('The "Threshold" value must be between 0 and 1.')

# Binarize predictions at the given threshold value
predictions = jnp.where(predictions >= threshold, 1, 0)

# Count the true positive, false positive, and false negative instances
true_positives = jnp.sum(predictions * labels, axis=0)
false_positives = jnp.sum(predictions * (1 - labels), axis=0)
false_negatives = jnp.sum((1 - predictions) * labels, axis=0)

return cls(
true_positives=true_positives,
false_positives=false_positives,
false_negatives=false_negatives,
beta=beta,
)

# Compute the F-Beta score metric
def compute(self) -> jax.Array:
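"""Returns the F-Beta score computed from the accumulated counts.

Precision and recall are computed manually here because the metrax
Precision and Recall metrics output a single number, while matching the
Keras value requires an array of numbers. base.divide_no_nan guards each
division, so no epsilon fuzz factor is needed.
"""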

# Compute precision and recall from the accumulated counts
precision = base.divide_no_nan(
self.true_positives,
self.true_positives + self.false_positives,
)
recall = base.divide_no_nan(
self.true_positives,
self.true_positives + self.false_negatives,
)

b2 = self.beta**2
numerator = (1 + b2) * (precision * recall)
denominator = (b2 * precision) + recall

return base.divide_no_nan(numerator, denominator)
76 changes: 67 additions & 9 deletions src/metrax/classification_metrics_test.py
@@ -15,9 +15,10 @@
"""Tests for metrax classification metrics."""

import os
os.environ['KERAS_BACKEND'] = 'jax' # pylint: disable=g-import-not-at-top

from absl.testing import absltest # pylint: disable=g-import-not-at-top
from absl.testing import parameterized
import jax.numpy as jnp
import keras
@@ -47,6 +48,7 @@
(BATCHES, 1),
).astype(np.float32)


class ClassificationMetricsTest(parameterized.TestCase):

def test_precision_empty(self):
@@ -273,17 +275,53 @@ def test_aucroc(self, inputs, dtype):
atol=1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-7,
)

# Testing function for F-Beta classification
# name, output true, output prediction, threshold, beta
@parameterized.named_parameters(
('basic_f16_beta_1.0', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.5, 1.0),
('basic_f32_beta_1.0', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.5, 1.0),
(
'low_threshold_f32_beta_1.0',
OUTPUT_LABELS,
OUTPUT_PREDS_F32,
0.1,
1.0,
),
(
'high_threshold_bf16_beta_1.0',
OUTPUT_LABELS,
OUTPUT_PREDS_BF16,
0.7,
1.0,
),
(
'batch_size_one_beta_1.0',
OUTPUT_LABELS_BS1,
OUTPUT_PREDS_BS1,
0.5,
1.0,
),
(
'high_threshold_f16_beta_2.0',
OUTPUT_LABELS,
OUTPUT_PREDS_F16,
0.7,
2.0,
),
(
'high_threshold_f32_beta_2.0',
OUTPUT_LABELS,
OUTPUT_PREDS_F32,
0.7,
2.0,
),
(
'low_threshold_bf16_beta_2.0',
OUTPUT_LABELS,
OUTPUT_PREDS_BF16,
0.1,
2.0,
),
('low_threshold_f16_beta_3.0', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.1, 3.0),
('basic_bf16_beta_3.0', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.5, 3.0),
)
@@ -307,6 +345,26 @@ def test_fbetascore(self, y_true, y_pred, threshold, beta):
atol=atol,
)

def test_fbeta_default_init(self):
"""Test usage of FBetaScore with default beta (should be 1.0)."""
y_true = jnp.array([1, 0, 1, 0], dtype=jnp.float32)
y_pred = jnp.array([1, 0, 0, 1], dtype=jnp.float32)

# Init without beta argument
metric = metrax.FBetaScore.from_model_output(
predictions=y_pred,
labels=y_true,
)

# Should default to beta=1.0 (F1 Score)
self.assertEqual(metric.beta, 1.0)

# Calculate expected F1
# TP=1 (idx 0), FP=1 (idx 3), FN=1 (idx 2)
# Precision = 1/2, Recall = 1/2
# F1 = 2 * (0.5 * 0.5) / (0.5 + 0.5) = 0.5
self.assertAlmostEqual(metric.compute(), 0.5)
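
# A hedged sketch of an additional check: from_model_output validates beta
# (see the ValueError checks in classification_metrics.py), so a
# non-positive beta should raise. The test name and inputs are illustrative.
def test_fbeta_invalid_beta(self):
"""Non-positive beta values should raise a ValueError."""
y_true = jnp.array([1, 0], dtype=jnp.float32)
y_pred = jnp.array([0.9, 0.1], dtype=jnp.float32)
with self.assertRaises(ValueError):
metrax.FBetaScore.from_model_output(
predictions=y_pred,
labels=y_true,
beta=-1.0,
)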


if __name__ == '__main__':
absltest.main()