@@ -581,145 +581,117 @@ def compute(self) -> jax.Array:
 
 @flax.struct.dataclass
 class FBetaScore(clu_metrics.Metric):
-  """
-  F-Beta score Metric class
-  Computes the F-Beta score for the binary classification given 'predictions' and 'labels'.
-
-  Formula for F-Beta Score:
-  b2 = beta ** 2
-  f_beta_score = ((1 + b2) * (precision * recall)) / (precision * b2 + recall)
-
-  F-Beta turns into the F1 Score when beta = 1.0
-
-  Attributes:
-    true_positives: The count of true positive instances from the given data,
-      label, and threshold.
-    false_positives: The count of false positive instances from the given data,
-      label, and threshold.
-    false_negatives: The count of false negative instances from the given data,
-      label, and threshold.
-    beta: The beta value used in the F-Score metric
-  """
+  """F-Beta score Metric class.
 
-  true_positives: jax.Array
-  false_positives: jax.Array
-  false_negatives: jax.Array
-  beta: float = 1.0
-
-  # Reset the variables for the class
-  @classmethod
-  def empty(cls) -> 'FBetaScore':
-    return cls(
-        true_positives=jnp.array(0, jnp.float32),
-        false_positives=jnp.array(0, jnp.float32),
-        false_negatives=jnp.array(0, jnp.float32),
-        beta=1.0,
-    )
-
-  @classmethod
-  def from_model_output(
-      cls,
-      predictions: jax.Array,
-      labels: jax.Array,
-      beta=beta,
-      threshold=0.5,) -> 'FBetaScore':
-    """Updates the metric.
-    Note: When only predictions and labels are given, the score calculated
-    is the F1 score if the FBetaScore beta value has not been previously modified.
-
-    Args:
-      predictions: A floating point 1D vector whose values are in the range [0,
-        1]. The shape should be (batch_size,).
-      labels: True value. The value is expected to be 0 or 1. The shape should
-        be (batch_size,).
-      beta: beta value to use in the F-Score metric. A floating number.
-      threshold: threshold value to use in the F-Score metric. A floating number.
-
-    Returns:
-      The updated FBetaScore object.
-
-    Raises:
-      ValueError: If type of `labels` is wrong or the shapes of `predictions`
-        and `labels` are incompatible. If the beta or threshold are invalid values
-        an error is raised as well.
-    """
-
-    # Ensure that beta is a floating number and a valid number
-    if not isinstance(beta, float):
-      raise ValueError('The "Beta" value must be a floating number.')
-    if beta <= 0.0:
-      raise ValueError('The "Beta" value must be larger than 0.0.')
-
-    # Ensure that threshold is a floating number and a valid number
-    if not isinstance(threshold, float):
-      raise ValueError('The "Threshold" value must be a floating number.')
-    if threshold < 0.0 or threshold > 1.0:
-      raise ValueError('The "Threshold" value must be between 0 and 1.')
-
-    # Modify predictions with the given threshold value
-    predictions = jnp.where(predictions >= threshold, 1, 0)
-
-    # Assign the true_positive, false_positive, and false_negative their values
-    """
-    We are calculating these values manually instead of using Metrax's
-    precision and recall classes. This is because the Metrax versions end up
-    outputting a single numerical answer when we need an array of numbers.
-    """
-    true_positives = jnp.sum(predictions * labels, axis=0)
-    false_positives = jnp.sum(predictions * (1 - labels), axis=0)
-    false_negatives = jnp.sum((1 - predictions) * labels, axis=0)
-
-    return cls(true_positives=true_positives,
-               false_positives=false_positives,
-               false_negatives=false_negatives,
-               beta=beta)
+  Computes the F-Beta score for the binary classification given 'predictions'
+  and 'labels'.
 
+  Formula for F-Beta Score:
+    b2 = beta ** 2
+    f_beta_score = ((1 + b2) * (precision * recall)) / (precision * b2 +
+    recall)
+
+  F-Beta turns into the F1 Score when beta = 1.0
+
+  Attributes:
+    true_positives: The count of true positive instances from the given data,
+      label, and threshold.
+    false_positives: The count of false positive instances from the given
+      data, label, and threshold.
+    false_negatives: The count of false negative instances from the given
+      data, label, and threshold.
+    beta: The beta value used in the F-Score metric
+  """
+
+  true_positives: jax.Array
+  false_positives: jax.Array
+  false_negatives: jax.Array
+  beta: float = 1.0
+
+  # Reset the variables for the class
+  @classmethod
+  def empty(cls) -> 'FBetaScore':
+    return cls(
+        true_positives=jnp.array(0, jnp.float32),
+        false_positives=jnp.array(0, jnp.float32),
+        false_negatives=jnp.array(0, jnp.float32),
+        beta=1.0,
+    )
+
+  @classmethod
+  def from_model_output(
+      cls,
+      predictions: jax.Array,
+      labels: jax.Array,
+      beta: float = 1.0,
+      threshold: float = 0.5,
+  ) -> 'FBetaScore':
+    """Updates the metric.
+
+    Note: When only predictions and labels are given, the score calculated is
+    the F1 score if the FBetaScore beta value has not been previously modified.
+
+    Args:
+      predictions: A floating point 1D vector whose values are in the range
+        [0, 1]. The shape should be (batch_size,).
+      labels: True value. The value is expected to be 0 or 1. The shape should
+        be (batch_size,).
+      beta: beta value to use in the F-Score metric. A floating number.
+      threshold: threshold value to use in the F-Score metric. A floating
+        number.
+
+    Returns:
+      The updated FBetaScore object.
+
+    Raises:
+      ValueError: If the type of `labels` is wrong or the shapes of
+        `predictions` and `labels` are incompatible. An error is raised as
+        well if the `beta` or `threshold` arguments are outside their valid
+        ranges.
     """
-    This function is currently unused as the 'from_model_output' function can handle the whole
-    dataset without needing to split and merge them. I'm leaving this here for now incase we want to
-    repurpose this or need to change something that requires this function's use again. This function would need
-    to be reworked for it to work with the current implementation of this class.
-    """
-    # # Merge datasets together
-    # def merge(self, other: 'FBetaScore') -> 'FBetaScore':
-    #
-    #   # Check if the incoming beta is the same value as the current beta
-    #   if other.beta == self.beta:
-    #     return type(self)(
-    #         true_positives = self.true_positives + other.true_positives,
-    #         false_positives = self.false_positives + other.false_positives,
-    #         false_negatives = self.false_negatives + other.false_negatives,
-    #         beta=self.beta,
-    #     )
-    #   else:
-    #     raise ValueError('The "Beta" values between the two are not equal.')
-
-  # Compute the F-Beta score metric
-  def compute(self) -> jax.Array:
-
-    # Epsilon fuz factor required to match with the keras version
-    epsilon = 1e-7
-
-    # Manually calculate precision and recall
-    """
-    This is done in this manner since the metrax variants of precision
-    and recall output only single numbers. To match the keras value
-    we need an array of numbers to work with.
-    """
-    precision = jnp.divide(
-        self.true_positives,
-        self.true_positives + self.false_positives + epsilon,
-    )
-    recall = jnp.divide(
-        self.true_positives,
-        self.true_positives + self.false_negatives + epsilon,
-    )
-
-    # Compute the numerator and denominator of the F-Score formula
-    b2 = self.beta ** 2
-    numerator = (1 + b2) * (precision * recall)
-    denominator = (b2 * precision) + recall
-
-    return base.divide_no_nan(
-        numerator, denominator + epsilon,
-    )
+
+    # Ensure that beta is a floating number and a valid number
+    if not isinstance(beta, float):
+      raise ValueError('The "Beta" value must be a floating number.')
+    if beta <= 0.0:
+      raise ValueError('The "Beta" value must be larger than 0.0.')
+
+    # Ensure that threshold is a floating number and a valid number
+    if not isinstance(threshold, float):
+      raise ValueError('The "Threshold" value must be a floating number.')
+    if threshold < 0.0 or threshold > 1.0:
+      raise ValueError('The "Threshold" value must be between 0 and 1.')
+
+    # Modify predictions with the given threshold value
+    predictions = jnp.where(predictions >= threshold, 1, 0)
+
+    # Assign the true_positive, false_positive, and false_negative their values
+    true_positives = jnp.sum(predictions * labels, axis=0)
+    false_positives = jnp.sum(predictions * (1 - labels), axis=0)
+    false_negatives = jnp.sum((1 - predictions) * labels, axis=0)
+
+    return cls(
+        true_positives=true_positives,
+        false_positives=false_positives,
+        false_negatives=false_negatives,
+        beta=beta,
+    )
+
+  # Compute the F-Beta score metric
+  def compute(self) -> jax.Array:
+
+    # Compute the numerator and denominator of the F-Score formula
+    precision = base.divide_no_nan(
+        self.true_positives,
+        self.true_positives + self.false_positives,
+    )
+    recall = base.divide_no_nan(
+        self.true_positives,
+        self.true_positives + self.false_negatives,
+    )
+
+    b2 = self.beta ** 2
+    numerator = (1 + b2) * (precision * recall)
+    denominator = (b2 * precision) + recall
+
+    return base.divide_no_nan(numerator, denominator)
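
Below is a minimal usage sketch, not part of the change above, showing how the new FBetaScore could be exercised and hand-checked against the docstring formula. It assumes FBetaScore is importable from the metrics module this diff touches (the import path is not shown here), and the example inputs are made up for illustration.

import jax.numpy as jnp

# Hypothetical inputs: probabilities in [0, 1] and binary ground-truth labels.
predictions = jnp.array([0.1, 0.8, 0.6, 0.3])
labels = jnp.array([0, 1, 0, 1])

# Accumulate one batch and compute the score with beta = 0.5.
metric = FBetaScore.from_model_output(
    predictions=predictions, labels=labels, beta=0.5, threshold=0.5
)
score = metric.compute()

# Hand check against the docstring formula: thresholding at 0.5 gives
# predictions [0, 1, 1, 0], so TP = 1, FP = 1, FN = 1 and
# precision = recall = 0.5. With b2 = 0.25 the score is
# (1 + 0.25) * (0.5 * 0.5) / (0.25 * 0.5 + 0.5) = 0.3125 / 0.625 = 0.5.
print(score)  # expected to be approximately 0.5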