@@ -454,8 +454,10 @@ def compute(self) -> jax.Array:
454454
455455@flax .struct .dataclass
456456class AUCROC (clu_metrics .Metric ):
457- r"""Computes area under the receiver operation characteristic curve for binary classification given `predictions`
458- and `labels`.
457+ r"""Computes Area Under the ROC curve (AUC-ROC).
458+
459+ Computes area under the receiver operation characteristic curve for binary
460+ classification given `predictions` and `labels`.
459461
460462 The ROC curve shows the tradeoff between the true positive rate (TPR) and
461463 false positive
@@ -579,147 +581,120 @@ def compute(self) -> jax.Array:
579581 # Threshold goes from 0 to 1, so trapezoid is negative.
580582 return jnp .trapezoid (tp_rate , fp_rate ) * - 1
581583
584+
582585@flax .struct .dataclass
583586class FBetaScore (clu_metrics .Metric ):
584- """
585- F-Beta score Metric class
586- Computes the F-Beta score for the binary classification given 'predictions' and 'labels'.
587-
588- Formula for F-Beta Score:
589- b2 = beta ** 2
590- f_beta_score = ((1 + b2) * (precision * recall)) / (precision * b2 + recall)
591-
592- F-Beta turns into the F1 Score when beta = 1.0
593-
594- Attributes:
595- true_positives: The count of true positive instances from the given data,
596- label, and threshold.
597- false_positives: The count of false positive instances from the given data,
598- label, and threshold.
599- false_negatives: The count of false negative instances from the given data,
600- label, and threshold.
601- beta: The beta value used in the F-Score metric
602- """
587+ """F-Beta score Metric class.
603588
604- true_positives : jax .Array
605- false_positives : jax .Array
606- false_negatives : jax .Array
607- beta : float = 1.0
608-
609- # Reset the variables for the class
610- @classmethod
611- def empty (cls ) -> 'FBetaScore' :
612- return cls (
613- true_positives = jnp .array (0 , jnp .float32 ),
614- false_positives = jnp .array (0 , jnp .float32 ),
615- false_negatives = jnp .array (0 , jnp .float32 ),
616- beta = 1.0 ,
617- )
618-
619- @classmethod
620- def from_model_output (
621- cls ,
622- predictions : jax .Array ,
623- labels : jax .Array ,
624- beta = beta ,
625- threshold = 0.5 ,) -> 'FBetaScore' :
626- """Updates the metric.
627- Note: When only predictions and labels are given, the score calculated
628- is the F1 score if the FBetaScore beta value has not been previously modified.
629-
630- Args:
631- predictions: A floating point 1D vector whose values are in the range [0,
632- 1]. The shape should be (batch_size,).
633- labels: True value. The value is expected to be 0 or 1. The shape should
634- be (batch_size,).
635- beta: beta value to use in the F-Score metric. A floating number.
636- threshold: threshold value to use in the F-Score metric. A floating number.
637-
638- Returns:
639- The updated FBetaScore object.
640-
641- Raises:
642- ValueError: If type of `labels` is wrong or the shapes of `predictions`
643- and `labels` are incompatible. If the beta or threshold are invalid values
644- an error is raised as well.
645- """
646-
647- # Ensure that beta is a floating number and a valid number
648- if not isinstance (beta , float ):
649- raise ValueError ('The "Beta" value must be a floating number.' )
650- if beta <= 0.0 :
651- raise ValueError ('The "Beta" value must be larger than 0.0.' )
652-
653- # Ensure that threshold is a floating number and a valid number
654- if not isinstance (threshold , float ):
655- raise ValueError ('The "Threshold" value must be a floating number.' )
656- if threshold < 0.0 or threshold > 1.0 :
657- raise ValueError ('The "Threshold" value must be between 0 and 1.' )
658-
659- # Modify predictions with the given threshold value
660- predictions = jnp .where (predictions >= threshold , 1 , 0 )
661-
662- # Assign the true_positive, false_positive, and false_negative their values
663- """
664- We are calculating these values manually instead of using Metrax's
665- precision and recall classes. This is because the Metrax versions end up
666- outputting a single numerical answer when we need an array of numbers.
667- """
668- true_positives = jnp .sum (predictions * labels , axis = 0 )
669- false_positives = jnp .sum (predictions * (1 - labels ), axis = 0 )
670- false_negatives = jnp .sum ((1 - predictions ) * labels , axis = 0 )
671-
672- return cls (true_positives = true_positives ,
673- false_positives = false_positives ,
674- false_negatives = false_negatives ,
675- beta = beta )
589+ Computes the F-Beta score for the binary classification given 'predictions'
590+ and 'labels'.
676591
592+ Formula for F-Beta Score:
593+ b2 = beta ** 2
594+ f_beta_score = ((1 + b2) * (precision * recall)) / (precision * b2 +
595+ recall)
596+
597+ F-Beta turns into the F1 Score when beta = 1.0
598+
599+ Attributes:
600+ true_positives: The count of true positive instances from the given data,
601+ label, and threshold.
602+ false_positives: The count of false positive instances from the given data,
603+ label, and threshold.
604+ false_negatives: The count of false negative instances from the given data,
605+ label, and threshold.
606+ beta: The beta value used in the F-Score metric.
607+ """
608+
609+ true_positives : jax .Array
610+ false_positives : jax .Array
611+ false_negatives : jax .Array
612+ beta : float = 1.0
613+
614+ # Reset the variables for the class
615+ @classmethod
616+ def empty (cls ) -> 'FBetaScore' :
617+ return cls (
618+ true_positives = jnp .array (0 , jnp .float32 ),
619+ false_positives = jnp .array (0 , jnp .float32 ),
620+ false_negatives = jnp .array (0 , jnp .float32 ),
621+ beta = 1.0 ,
622+ )
623+
624+ @classmethod
625+ def from_model_output (
626+ cls ,
627+ predictions : jax .Array ,
628+ labels : jax .Array ,
629+ beta : float = 1.0 ,
630+ threshold : float = 0.5 ,
631+ ) -> 'FBetaScore' :
632+ """Updates the metric.
633+
634+ Note: When only predictions and labels are given, the score calculated is
635+ the F1 score if the FBetaScore beta value has not been previously modified.
636+
637+ Args:
638+ predictions: A floating point 1D vector whose values are in the range
639+ [0, 1]. The shape should be (batch_size,).
640+ labels: True value. The value is expected to be 0 or 1. The shape should
641+ be (batch_size,).
642+ beta: beta value to use in the F-Score metric. A floating number.
643+ threshold: threshold value to use in the F-Score metric. A floating
644+ number.
645+
646+ Returns:
647+ The updated FBetaScore object.
648+
649+ Raises:
650+ ValueError: If type of `labels` is wrong or the shapes of `predictions`
651+ and `labels` are incompatible. If the beta or threshold are invalid
652+ values
653+ an error is raised as well.
677654 """
678- This function is currently unused as the 'from_model_output' function can handle the whole
679- dataset without needing to split and merge them. I'm leaving this here for now incase we want to
680- repurpose this or need to change something that requires this function's use again. This function would need
681- to be reworked for it to work with the current implementation of this class.
682- """
683- # # Merge datasets together
684- # def merge(self, other: 'FBetaScore') -> 'FBetaScore':
685- #
686- # # Check if the incoming beta is the same value as the current beta
687- # if other.beta == self.beta:
688- # return type(self)(
689- # true_positives = self.true_positives + other.true_positives,
690- # false_positives = self.false_positives + other.false_positives,
691- # false_negatives = self.false_negatives + other.false_negatives,
692- # beta=self.beta,
693- # )
694- # else:
695- # raise ValueError('The "Beta" values between the two are not equal.')
696-
697- # Compute the F-Beta score metric
698- def compute (self ) -> jax .Array :
699-
700- # Epsilon fuz factor required to match with the keras version
701- epsilon = 1e-7
702-
703- # Manually calculate precision and recall
704- """
705- This is done in this manner since the metrax variants of precision
706- and recall output only single numbers. To match the keras value
707- we need an array of numbers to work with.
708- """
709- precision = jnp .divide (
710- self .true_positives ,
711- self .true_positives + self .false_positives + epsilon ,
712- )
713- recall = jnp .divide (
714- self .true_positives ,
715- self .true_positives + self .false_negatives + epsilon ,
716- )
717-
718- # Compute the numerator and denominator of the F-Score formula
719- b2 = self .beta ** 2
720- numerator = (1 + b2 ) * (precision * recall )
721- denominator = (b2 * precision ) + recall
722-
723- return base .divide_no_nan (
724- numerator , denominator + epsilon ,
725- )
655+
656+ # Ensure that beta is a floating number and a valid number
657+ if not isinstance (beta , float ):
658+ raise ValueError ('The "Beta" value must be a floating number.' )
659+ if beta <= 0.0 :
660+ raise ValueError ('The "Beta" value must be larger than 0.0.' )
661+
662+ # Ensure that threshold is a floating number and a valid number
663+ if not isinstance (threshold , float ):
664+ raise ValueError ('The "Threshold" value must be a floating number.' )
665+ if threshold < 0.0 or threshold > 1.0 :
666+ raise ValueError ('The "Threshold" value must be between 0 and 1.' )
667+
668+ # Modify predictions with the given threshold value
669+ predictions = jnp .where (predictions >= threshold , 1 , 0 )
670+
671+ # Assign the true_positive, false_positive, and false_negative their values
672+ true_positives = jnp .sum (predictions * labels , axis = 0 )
673+ false_positives = jnp .sum (predictions * (1 - labels ), axis = 0 )
674+ false_negatives = jnp .sum ((1 - predictions ) * labels , axis = 0 )
675+
676+ return cls (
677+ true_positives = true_positives ,
678+ false_positives = false_positives ,
679+ false_negatives = false_negatives ,
680+ beta = beta ,
681+ )
682+
683+ # Compute the F-Beta score metric
684+ def compute (self ) -> jax .Array :
685+
686+ # Compute the numerator and denominator of the F-Score formula
687+ precision = base .divide_no_nan (
688+ self .true_positives ,
689+ self .true_positives + self .false_positives ,
690+ )
691+ recall = base .divide_no_nan (
692+ self .true_positives ,
693+ self .true_positives + self .false_negatives ,
694+ )
695+
696+ b2 = self .beta ** 2
697+ numerator = (1 + b2 ) * (precision * recall )
698+ denominator = (b2 * precision ) + recall
699+
700+ return base .divide_no_nan (numerator , denominator )
0 commit comments