Skip to content

Commit dd6d8f4

Browse files
committed
fix handling of 0-label predictions
1 parent 713b7ae commit dd6d8f4

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

chebai/result/analyse_sem.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,10 +217,16 @@ def set_label_names(self, label_names):
217217
self.label_successors = self.label_successors.unsqueeze(0)
218218

219219
def __call__(self, preds):
220+
if preds.shape[1] == 0:
221+
# no labels predicted
222+
return preds
223+
# preds shape: (n_samples, n_labels)
220224
preds_sum_orig = torch.sum(preds)
221225
# step 1: apply implications: for each class, set prediction to max of itself and all successors
222226
preds = preds.unsqueeze(1)
223227
preds_masked_succ = torch.where(self.label_successors, preds, 0)
228+
# preds_masked_succ shape: (n_samples, n_labels, n_labels)
229+
224230
preds = preds_masked_succ.max(dim=2).values
225231
if torch.sum(preds) != preds_sum_orig:
226232
print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}")

0 commit comments

Comments
 (0)