Skip to content

Commit 0cbf236

Browse files
authored
Improve MultiLabelConfusionMatrix.toString (#136)
* more readable MultiLabelConfusionMatrix.toString * include label for which there are no predictions in ConfusionMetrics.accuracy log message
1 parent bd2799f commit 0cbf236

File tree

2 files changed

+14
-10
lines changed

2 files changed

+14
-10
lines changed

Classification/Core/src/main/java/org/tribuo/classification/evaluation/ConfusionMetrics.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ public static <T extends Classifiable<T>> double accuracy(T label, ConfusionMatr
6060
double support = cm.support(label);
6161
// handle div-by-zero
6262
if (support == 0d) {
63-
logger.warning("No predictions: accuracy ill-defined");
63+
logger.warning("No predictions for " + label + ": accuracy ill-defined");
6464
return Double.NaN;
6565
}
6666
return cm.tp(label) / cm.support(label);

MultiLabel/Core/src/main/java/org/tribuo/multilabel/evaluation/MultiLabelConfusionMatrix.java

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import java.util.List;
2929
import java.util.Set;
3030
import java.util.function.Function;
31+
import java.util.stream.Collectors;
3132

3233
/**
3334
* A {@link ConfusionMatrix} which accepts {@link MultiLabel}s.
@@ -158,15 +159,18 @@ public double confusion(MultiLabel predicted, MultiLabel truth) {
158159

159160
@Override
160161
public String toString() {
161-
StringBuilder sb = new StringBuilder();
162-
sb.append("[");
163-
for (int i = 0; i < mcm.length; i++) {
164-
DenseMatrix cm = mcm[i];
165-
sb.append(cm.toString());
166-
sb.append("\n");
167-
}
168-
sb.append("]");
169-
return sb.toString();
162+
return getDomain().getDomain().stream()
163+
.map(multiLabel -> {
164+
final int tp = (int) tp(multiLabel);
165+
final int fn = (int) fn(multiLabel);
166+
final int fp = (int) fp(multiLabel);
167+
final int tn = (int) tn(multiLabel);
168+
return String.join("\n",
169+
multiLabel.toString(),
170+
String.format(" [tn: %,d fn: %,d]", tn, fn),
171+
String.format(" [fp: %,d tp: %,d]", fp, tp));
172+
}
173+
).collect(Collectors.joining("\n"));
170174
}
171175

172176
static ConfusionMatrixTuple tabulate(ImmutableOutputInfo<MultiLabel> domain, List<Prediction<MultiLabel>> predictions) {

0 commit comments

Comments
 (0)