diff --git a/src/metrax/classification_metrics.py b/src/metrax/classification_metrics.py
index 57f054d..cf3a070 100644
--- a/src/metrax/classification_metrics.py
+++ b/src/metrax/classification_metrics.py
@@ -454,8 +454,10 @@ def compute(self) -> jax.Array:
 
 
 @flax.struct.dataclass
 class AUCROC(clu_metrics.Metric):
-  r"""Computes area under the receiver operation characteristic curve for binary classification given `predictions`
-  and `labels`.
+  r"""Computes the area under the ROC curve (AUC-ROC).
+
+  Computes area under the receiver operating characteristic curve for binary
+  classification given `predictions` and `labels`.
 
   The ROC curve shows the tradeoff between the true positive rate (TPR) and false positive
@@ -579,147 +581,120 @@ def compute(self) -> jax.Array:
     # Threshold goes from 0 to 1, so trapezoid is negative.
     return jnp.trapezoid(tp_rate, fp_rate) * -1
 
+
 @flax.struct.dataclass
 class FBetaScore(clu_metrics.Metric):
-  """
-  F-Beta score Metric class
-  Computes the F-Beta score for the binary classification given 'predictions' and 'labels'.
-
-  Formula for F-Beta Score:
-  b2 = beta ** 2
-  f_beta_score = ((1 + b2) * (precision * recall)) / (precision * b2 + recall)
-
-  F-Beta turns into the F1 Score when beta = 1.0
-
-  Attributes:
-    true_positives: The count of true positive instances from the given data,
-      label, and threshold.
-    false_positives: The count of false positive instances from the given data,
-      label, and threshold.
-    false_negatives: The count of false negative instances from the given data,
-      label, and threshold.
-    beta: The beta value used in the F-Score metric
-  """
+  """F-Beta score metric.
 
-  true_positives: jax.Array
-  false_positives: jax.Array
-  false_negatives: jax.Array
-  beta: float = 1.0
-
-  # Reset the variables for the class
-  @classmethod
-  def empty(cls) -> 'FBetaScore':
-    return cls(
-        true_positives = jnp.array(0, jnp.float32),
-        false_positives = jnp.array(0, jnp.float32),
-        false_negatives = jnp.array(0, jnp.float32),
-        beta=1.0,
-    )
-
-  @classmethod
-  def from_model_output(
-      cls,
-      predictions: jax.Array,
-      labels: jax.Array,
-      beta = beta,
-      threshold = 0.5,) -> 'FBetaScore':
-    """Updates the metric.
-    Note: When only predictions and labels are given, the score calculated
-    is the F1 score if the FBetaScore beta value has not been previously modified.
-
-    Args:
-      predictions: A floating point 1D vector whose values are in the range [0,
-        1]. The shape should be (batch_size,).
-      labels: True value. The value is expected to be 0 or 1. The shape should
-        be (batch_size,).
-      beta: beta value to use in the F-Score metric. A floating number.
-      threshold: threshold value to use in the F-Score metric. A floating number.
-
-    Returns:
-      The updated FBetaScore object.
-
-    Raises:
-      ValueError: If type of `labels` is wrong or the shapes of `predictions`
-        and `labels` are incompatible. If the beta or threshold are invalid values
-        an error is raised as well.
-    """
-
-    # Ensure that beta is a floating number and a valid number
-    if not isinstance(beta, float):
-      raise ValueError('The "Beta" value must be a floating number.')
-    if beta <= 0.0:
-      raise ValueError('The "Beta" value must be larger than 0.0.')
-
-    # Ensure that threshold is a floating number and a valid number
-    if not isinstance(threshold, float):
-      raise ValueError('The "Threshold" value must be a floating number.')
-    if threshold < 0.0 or threshold > 1.0:
-      raise ValueError('The "Threshold" value must be between 0 and 1.')
-
-    # Modify predictions with the given threshold value
-    predictions = jnp.where(predictions >= threshold, 1, 0)
-
-    # Assign the true_positive, false_positive, and false_negative their values
-    """
-    We are calculating these values manually instead of using Metrax's
-    precision and recall classes. This is because the Metrax versions end up
-    outputting a single numerical answer when we need an array of numbers.
-    """
-    true_positives = jnp.sum(predictions * labels, axis = 0)
-    false_positives = jnp.sum(predictions * (1 - labels), axis = 0)
-    false_negatives = jnp.sum((1- predictions) * labels, axis = 0)
-
-    return cls(true_positives = true_positives,
-               false_positives = false_positives,
-               false_negatives = false_negatives,
-               beta = beta)
+  Computes the F-Beta score for binary classification given 'predictions'
+  and 'labels'.
+
+  Formula for F-Beta Score:
+    b2 = beta ** 2
+    f_beta_score = ((1 + b2) * (precision * recall)) / (precision * b2 +
+    recall)
+
+  The F-Beta score reduces to the F1 score when beta = 1.0.
+
+  Attributes:
+    true_positives: The count of true positive instances from the given data,
+      labels, and threshold.
+    false_positives: The count of false positive instances from the given data,
+      labels, and threshold.
+    false_negatives: The count of false negative instances from the given data,
+      labels, and threshold.
+    beta: The beta value used in the F-Beta score metric.
+  """
+
+  true_positives: jax.Array
+  false_positives: jax.Array
+  false_negatives: jax.Array
+  beta: float = 1.0
+
+  # Reset the variables for the class
+  @classmethod
+  def empty(cls) -> 'FBetaScore':
+    return cls(
+        true_positives=jnp.array(0, jnp.float32),
+        false_positives=jnp.array(0, jnp.float32),
+        false_negatives=jnp.array(0, jnp.float32),
+        beta=1.0,
+    )
+
+  @classmethod
+  def from_model_output(
+      cls,
+      predictions: jax.Array,
+      labels: jax.Array,
+      beta: float = 1.0,
+      threshold: float = 0.5,
+  ) -> 'FBetaScore':
+    """Updates the metric.
+
+    Note: When only `predictions` and `labels` are given, the default `beta`
+      of 1.0 is used, so the computed score is the F1 score.
+
+    Args:
+      predictions: A floating point 1D vector whose values are in the range
+        [0, 1]. The shape should be (batch_size,).
+      labels: True value. The value is expected to be 0 or 1. The shape should
+        be (batch_size,).
+      beta: The beta value to use in the F-Beta score. Must be a positive
+        float.
+      threshold: The threshold used to binarize `predictions`. Must be a float
+        in [0, 1].
+
+    Returns:
+      The updated FBetaScore object.
+
+    Raises:
+      ValueError: If the type of `labels` is wrong, if the shapes of
+        `predictions` and `labels` are incompatible, or if `beta` or
+        `threshold` is an invalid value.
     """
-    This function is currently unused as the 'from_model_output' function can handle the whole
-    dataset without needing to split and merge them. I'm leaving this here for now incase we want to
-    repurpose this or need to change something that requires this function's use again. This function would need
-    to be reworked for it to work with the current implementation of this class.
-    """
-    # # Merge datasets together
-    # def merge(self, other: 'FBetaScore') -> 'FBetaScore':
-    #
-    #   # Check if the incoming beta is the same value as the current beta
-    #   if other.beta == self.beta:
-    #     return type(self)(
-    #         true_positives = self.true_positives + other.true_positives,
-    #         false_positives = self.false_positives + other.false_positives,
-    #         false_negatives = self.false_negatives + other.false_negatives,
-    #         beta=self.beta,
-    #     )
-    #   else:
-    #     raise ValueError('The "Beta" values between the two are not equal.')
-
-  # Compute the F-Beta score metric
-  def compute(self) -> jax.Array:
-
-    # Epsilon fuz factor required to match with the keras version
-    epsilon = 1e-7
-
-    # Manually calculate precision and recall
-    """
-    This is done in this manner since the metrax variants of precision
-    and recall output only single numbers. To match the keras value
-    we need an array of numbers to work with.
-    """
-    precision = jnp.divide(
-        self.true_positives,
-        self.true_positives + self.false_positives + epsilon,
-    )
-    recall = jnp.divide(
-        self.true_positives,
-        self.true_positives + self.false_negatives + epsilon,
-    )
-
-    # Compute the numerator and denominator of the F-Score formula
-    b2 = self.beta ** 2
-    numerator = (1 + b2) * (precision * recall)
-    denominator = (b2 * precision) + recall
-
-    return base.divide_no_nan(
-        numerator, denominator + epsilon,
-    )
+
+    # Ensure that beta is a float and a valid value
+    if not isinstance(beta, float):
+      raise ValueError('The "Beta" value must be a floating number.')
+    if beta <= 0.0:
+      raise ValueError('The "Beta" value must be larger than 0.0.')
+
+    # Ensure that threshold is a float and a valid value
+    if not isinstance(threshold, float):
+      raise ValueError('The "Threshold" value must be a floating number.')
+    if threshold < 0.0 or threshold > 1.0:
+      raise ValueError('The "Threshold" value must be between 0 and 1.')
+
+    # Binarize predictions with the given threshold value
+    predictions = jnp.where(predictions >= threshold, 1, 0)
+
+    # Compute the true positive, false positive, and false negative counts
+    true_positives = jnp.sum(predictions * labels, axis=0)
+    false_positives = jnp.sum(predictions * (1 - labels), axis=0)
+    false_negatives = jnp.sum((1 - predictions) * labels, axis=0)
+
+    return cls(
+        true_positives=true_positives,
+        false_positives=false_positives,
+        false_negatives=false_negatives,
+        beta=beta,
+    )
+
+  # Compute the F-Beta score metric
+  def compute(self) -> jax.Array:
+
+    # Compute precision and recall, then the F-Beta numerator and denominator
+    precision = base.divide_no_nan(
+        self.true_positives,
+        self.true_positives + self.false_positives,
+    )
+    recall = base.divide_no_nan(
+        self.true_positives,
+        self.true_positives + self.false_negatives,
+    )
+
+    b2 = self.beta**2
+    numerator = (1 + b2) * (precision * recall)
+    denominator = (b2 * precision) + recall
+
+    return base.divide_no_nan(numerator, denominator)
diff --git a/src/metrax/classification_metrics_test.py b/src/metrax/classification_metrics_test.py
index 2e54451..c9255c3 100644
--- a/src/metrax/classification_metrics_test.py
+++ b/src/metrax/classification_metrics_test.py
@@ -15,9 +15,10 @@
 """Tests for metrax classification metrics."""
 
 import os
 
-os.environ['KERAS_BACKEND'] = 'jax'
-from absl.testing import absltest
+os.environ['KERAS_BACKEND'] = 'jax'  # pylint: disable=g-import-not-at-top
+
+from absl.testing import absltest  # pylint: disable=g-import-not-at-top
 from absl.testing import parameterized
 import jax.numpy as jnp
 import keras
@@ -47,6 +48,7 @@
     (BATCHES, 1),
 ).astype(np.float32)
 
+
 class ClassificationMetricsTest(parameterized.TestCase):
 
   def test_precision_empty(self):
@@ -273,17 +275,53 @@ def test_aucroc(self, inputs, dtype):
         atol=1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-7,
     )
 
-# Testing function for F-Beta classification
+  # Testing function for F-Beta classification
   # name, output true, output prediction, threshold, beta
   @parameterized.named_parameters(
      ('basic_f16_beta_1.0', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.5, 1.0),
      ('basic_f32_beta_1.0', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.5, 1.0),
-     ('low_threshold_f32_beta_1.0', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.1, 1.0),
-     ('high_threshold_bf16_beta_1.0', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.7, 1.0),
-     ('batch_size_one_beta_1.0', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, 0.5, 1.0),
-     ('high_threshold_f16_beta_2.0', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.7, 2.0),
-     ('high_threshold_f32_beta_2.0', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.7, 2.0),
-     ('low_threshold_bf16_beta_2.0', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.1, 2.0),
+     (
+         'low_threshold_f32_beta_1.0',
+         OUTPUT_LABELS,
+         OUTPUT_PREDS_F32,
+         0.1,
+         1.0,
+     ),
+     (
+         'high_threshold_bf16_beta_1.0',
+         OUTPUT_LABELS,
+         OUTPUT_PREDS_BF16,
+         0.7,
+         1.0,
+     ),
+     (
+         'batch_size_one_beta_1.0',
+         OUTPUT_LABELS_BS1,
+         OUTPUT_PREDS_BS1,
+         0.5,
+         1.0,
+     ),
+     (
+         'high_threshold_f16_beta_2.0',
+         OUTPUT_LABELS,
+         OUTPUT_PREDS_F16,
+         0.7,
+         2.0,
+     ),
+     (
+         'high_threshold_f32_beta_2.0',
+         OUTPUT_LABELS,
+         OUTPUT_PREDS_F32,
+         0.7,
+         2.0,
+     ),
+     (
+         'low_threshold_bf16_beta_2.0',
+         OUTPUT_LABELS,
+         OUTPUT_PREDS_BF16,
+         0.1,
+         2.0,
+     ),
      ('low_threshold_f16_beta_3.0', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.1, 3.0),
      ('basic_bf16_beta_3.0', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.5, 3.0),
   )
@@ -307,6 +345,26 @@ def test_fbetascore(self, y_true, y_pred, threshold, beta):
         atol=atol,
     )
 
+  def test_fbeta_default_init(self):
+    """Test usage of FBetaScore with default beta (should be 1.0)."""
+    y_true = jnp.array([1, 0, 1, 0], dtype=jnp.float32)
+    y_pred = jnp.array([1, 0, 0, 1], dtype=jnp.float32)
+
+    # Init without beta argument
+    metric = metrax.FBetaScore.from_model_output(
+        predictions=y_pred,
+        labels=y_true,
+    )
+
+    # Should default to beta=1.0 (F1 Score)
+    self.assertEqual(metric.beta, 1.0)
+
+    # Calculate expected F1
+    # TP=1 (idx 0), FP=1 (idx 3), FN=1 (idx 2)
+    # Precision = 1/2, Recall = 1/2
+    # F1 = 2 * (0.5 * 0.5) / (0.5 + 0.5) = 0.5
+    self.assertAlmostEqual(metric.compute(), 0.5)
+
 
 if __name__ == '__main__':
   absltest.main()
diff --git a/src/metrax/image_metrics.py b/src/metrax/image_metrics.py
index c19a75e..e6f3479 100644
--- a/src/metrax/image_metrics.py
+++ b/src/metrax/image_metrics.py
@@ -539,14 +539,15 @@ def _calculate_psnr(
   """Computes PSNR (Peak Signal-to-Noise Ratio) values.
 
   Args:
-      img1: Predicted images, shape ``(batch, H, W, C)``.
-      img2: Ground‑truth images, same shape as ``img1``.
-      max_val: Dynamic range of the images (e.g. ``1.0`` or ``255``).
-      eps: Small constant to avoid ``log(0)`` when images are identical.
+    img1: Predicted images, shape ``(batch, H, W, C)``.
+    img2: Ground‑truth images, same shape as ``img1``.
+    max_val: Dynamic range of the images (e.g. ``1.0`` or ``255``).
+    eps: Small constant to avoid ``log(0)`` when images are identical.
 
-    Returns:
-      A 1D JAX array of shape ``(batch,)`` containing PSNR in dB.
+  Returns:
+    A 1D JAX array of shape ``(batch,)`` containing PSNR in dB.
   """
+
   if img1.shape != img2.shape:
     raise ValueError(
         f'Input images must have the same shape, got {img1.shape} and'
@@ -663,9 +664,9 @@ def merge(self, other: 'Dice') -> 'Dice':
     )
 
   def compute(self) -> jax.Array:
-    """Returns the final Dice coefficient."""
-    epsilon = 1e-7
-    return (2.0 * self.intersection) / (self.sum_pred + self.sum_true + epsilon)
+    return base.divide_no_nan(
+        2.0 * self.intersection, self.sum_pred + self.sum_true
+    )
 
 
 @flax.struct.dataclass
@@ -705,4 +706,3 @@ def from_model_output(
 
     cosine_similarity = dot_product / (predictions_norm * targets_norm)
     return super().from_model_output(values=cosine_similarity)
-
diff --git a/src/metrax/image_metrics_test.py b/src/metrax/image_metrics_test.py
index a3f9fbf..22fea9e 100644
--- a/src/metrax/image_metrics_test.py
+++ b/src/metrax/image_metrics_test.py
@@ -15,9 +15,10 @@
 """Tests for metrax image metrics."""
 
 import os
 
-os.environ['KERAS_BACKEND'] = 'jax'
-from absl.testing import absltest
+os.environ['KERAS_BACKEND'] = 'jax'  # pylint: disable=g-import-not-at-top
+
+from absl.testing import absltest  # pylint: disable=g-import-not-at-top
 from absl.testing import parameterized
 import jax.numpy as jnp
 import keras
@@ -77,7 +78,8 @@
 
 B_IOU, H_IOU, W_IOU = 2, 4, 4  # Common batch, height, width for IoU tests
 
-# Case IoU 1: Binary segmentation (num_classes=2), target_class_ids=[1] (foreground)
+# Case IoU 1: Binary segmentation (num_classes=2), target_class_ids=[1]
+# (foreground)
 # Targets: (Batch, Height, Width)
 TARGETS_IOU_1 = np.array(
     [
@@ -118,7 +120,8 @@
     [1]
 )
 # Expected Keras/Metrax result: mean([2/6, 2/6]) = 1/3
-# Case IoU 2: Multi-class (num_classes=3), target_class_ids=[0, 2] (mean over these two)
+# Case IoU 2: Multi-class (num_classes=3), target_class_ids=[0, 2]
+# (mean over these two)
 TARGETS_IOU_2 = np.array(
     [
         [[0, 0, 1, 1], [0, 1, 1, 2], [2, 2, 1, 0], [0, 0, 2, 2]],  # B1
@@ -136,7 +139,8 @@
 NUM_CLASSES_IOU_2 = 3
 TARGET_CLASS_IDS_IOU_2 = np.array([0, 2])
 
-# Case IoU 3: Perfect overlap for target class [1] (using a smaller H, W for simplicity)
+# Case IoU 3: Perfect overlap for target class [1]
+# (using a smaller H, W for simplicity)
 _H_IOU3, _W_IOU3 = 3, 3
 TARGETS_IOU_3 = np.array(
     [
@@ -541,11 +545,12 @@ def test_dice(self, y_true, y_pred):
     y_pred = jnp.asarray(y_pred, jnp.float32)
 
     # Manually compute expected Dice
-    eps = 1e-7
     intersection = jnp.sum(y_true * y_pred)
     sum_pred = jnp.sum(y_pred)
     sum_true = jnp.sum(y_true)
-    expected = (2.0 * intersection) / (sum_pred + sum_true + eps)
+    expected = metrax.base.divide_no_nan(
+        2.0 * intersection, sum_pred + sum_true
+    )
 
     # Compute using the metric class
     metric = metrax.Dice.from_model_output(predictions=y_pred, labels=y_true)