24
24
import java .util .logging .Level ;
25
25
import java .util .logging .Logger ;
26
26
import java .util .stream .Collectors ;
27
+ import java .util .stream .Stream ;
27
28
28
29
import org .junit .jupiter .api .Assertions ;
29
30
import org .junit .jupiter .api .BeforeAll ;
@@ -125,7 +126,7 @@ public static List<Prediction<Label>> mkSingleLabelPredictions(List<Prediction<M
125
126
// intersection(trueLabels, predicted) = true positives
126
127
// predicted - trueLabels = false positives
127
128
// trueLabels - predicted = false negatives
128
- return predicted .stream ().map (pred -> {
129
+ return Stream . concat ( predicted .stream ().map (pred -> {
129
130
if (trueLabels .contains (pred )) {
130
131
return mkPrediction (pred .getLabel (), pred .getLabel ());
131
132
} else if (trueLabels .size () == 1 ) {
@@ -134,7 +135,19 @@ public static List<Prediction<Label>> mkSingleLabelPredictions(List<Prediction<M
134
135
// arbitrarily pick first trueLabel
135
136
return mkPrediction (trueLabels .iterator ().next ().getLabel (), pred .getLabel ());
136
137
}
137
- });
138
+ }),
139
+ // partially represent false negatives by calling them false positives tied to some predicted label if there is one
140
+ trueLabels .stream ().filter (t -> !predicted .contains (t )).flatMap (fnTrueLabel -> {
141
+ if (predicted .isEmpty ()) {
142
+ // nothing to pin this on
143
+ return Stream .of ();
144
+ } else if (predicted .size () == 1 ) {
145
+ return Stream .of (mkPrediction (fnTrueLabel .getLabel (), predicted .iterator ().next ().getLabel ()));
146
+ } else {
147
+ // arbitrarily pick first predicted label
148
+ return Stream .of (mkPrediction (fnTrueLabel .getLabel (), predicted .iterator ().next ().getLabel ()));
149
+ }
150
+ }));
138
151
}).collect (Collectors .toList ());
139
152
}
140
153
0 commit comments