Skip to content

Commit b5b1206

Browse files
committed
WIP utility singleLabelConfusionMatrix(List<Prediction<MultiLabel>> predictions)
- partially represent false negatives by calling them false positives tied to some predicted label if there is one
1 parent 35735aa commit b5b1206

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

MultiLabel/Core/src/test/java/org/tribuo/multilabel/IndependentMultiLabelTest.java

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import java.util.logging.Level;
2525
import java.util.logging.Logger;
2626
import java.util.stream.Collectors;
27+
import java.util.stream.Stream;
2728

2829
import org.junit.jupiter.api.Assertions;
2930
import org.junit.jupiter.api.BeforeAll;
@@ -125,7 +126,7 @@ public static List<Prediction<Label>> mkSingleLabelPredictions(List<Prediction<M
125126
// intersection(trueLabels, predicted) = true positives
126127
// predicted - trueLabels = false positives
127128
// trueLabels - predicted = false negatives
128-
return predicted.stream().map(pred -> {
129+
return Stream.concat(predicted.stream().map(pred -> {
129130
if (trueLabels.contains(pred)) {
130131
return mkPrediction(pred.getLabel(), pred.getLabel());
131132
} else if (trueLabels.size() == 1) {
@@ -134,7 +135,19 @@ public static List<Prediction<Label>> mkSingleLabelPredictions(List<Prediction<M
134135
// arbitrarily pick first trueLabel
135136
return mkPrediction(trueLabels.iterator().next().getLabel(), pred.getLabel());
136137
}
137-
});
138+
}),
139+
// partially represent false negatives by calling them false positives tied to some predicted label if there is one
140+
trueLabels.stream().filter(t -> !predicted.contains(t)).flatMap(fnTrueLabel -> {
141+
if (predicted.isEmpty()) {
142+
// nothing to pin this on
143+
return Stream.of();
144+
} else if (predicted.size() == 1) {
145+
return Stream.of(mkPrediction(fnTrueLabel.getLabel(), predicted.iterator().next().getLabel()));
146+
} else {
147+
// arbitrarily pick first predicted label
148+
return Stream.of(mkPrediction(fnTrueLabel.getLabel(), predicted.iterator().next().getLabel()));
149+
}
150+
}));
138151
}).collect(Collectors.toList());
139152
}
140153

0 commit comments

Comments
 (0)