24
24
import java .util .logging .Level ;
25
25
import java .util .logging .Logger ;
26
26
import java .util .stream .Collectors ;
27
+ import java .util .stream .Stream ;
27
28
28
29
import org .junit .jupiter .api .Assertions ;
29
30
import org .junit .jupiter .api .BeforeAll ;
@@ -118,14 +119,19 @@ public static LabelConfusionMatrix singleLabelConfusionMatrix(final List<Predict
118
119
}
119
120
120
121
public static List <Prediction <Label >> mkSingleLabelPredictions (List <Prediction <MultiLabel >> predictions ) {
122
+ return mkSingleLabelPredictions (predictions , false );
123
+ }
124
+
125
+ public static List <Prediction <Label >> mkSingleLabelPredictions (List <Prediction <MultiLabel >> predictions ,
126
+ final boolean falseNegativeHeuristic ) {
121
127
return predictions .stream ()
122
128
.flatMap (p -> {
123
129
final Set <Label > trueLabels = p .getExample ().getOutput ().getLabelSet ();
124
130
final Set <Label > predicted = p .getOutput ().getLabelSet ();
125
131
// intersection(trueLabels, predicted) = true positives
126
132
// predicted - trueLabels = false positives
127
133
// trueLabels - predicted = false negatives
128
- return predicted .stream ().map (pred -> {
134
+ return Stream . concat ( predicted .stream ().map (pred -> {
129
135
if (trueLabels .contains (pred )) {
130
136
return mkPrediction (pred .getLabel (), pred .getLabel ());
131
137
} else if (trueLabels .size () == 1 ) {
@@ -134,7 +140,21 @@ public static List<Prediction<Label>> mkSingleLabelPredictions(List<Prediction<M
134
140
// arbitrarily pick first trueLabel
135
141
return mkPrediction (trueLabels .iterator ().next ().getLabel (), pred .getLabel ());
136
142
}
137
- });
143
+ }),
144
+ !falseNegativeHeuristic ? Stream .of () :
145
+ // partially represent false negatives by calling them false positives tied to some predicted label if there is one
146
+ trueLabels .stream ().filter (t -> !predicted .contains (t )).flatMap (fnTrueLabel -> {
147
+ if (predicted .isEmpty ()) {
148
+ // nothing to pin this on
149
+ return Stream .of ();
150
+ } else if (predicted .size () == 1 ) {
151
+ return Stream .of (mkPrediction (fnTrueLabel .getLabel (), predicted .iterator ().next ().getLabel ()));
152
+ } else {
153
+ // arbitrarily pick first predicted label
154
+ return Stream .of (mkPrediction (fnTrueLabel .getLabel (), predicted .iterator ().next ().getLabel ()));
155
+ }
156
+ })
157
+ );
138
158
}).collect (Collectors .toList ());
139
159
}
140
160
0 commit comments