File tree Expand file tree Collapse file tree 1 file changed +6
-0
lines changed Expand file tree Collapse file tree 1 file changed +6
-0
lines changed Original file line number Diff line number Diff line change @@ -217,10 +217,16 @@ def set_label_names(self, label_names):
217217 self .label_successors = self .label_successors .unsqueeze (0 )
218218
219219 def __call__ (self , preds ):
220+ if preds .shape [1 ] == 0 :
221+ # no labels predicted
222+ return preds
223+ # preds shape: (n_samples, n_labels)
220224 preds_sum_orig = torch .sum (preds )
221225 # step 1: apply implications: for each class, set prediction to max of itself and all successors
222226 preds = preds .unsqueeze (1 )
223227 preds_masked_succ = torch .where (self .label_successors , preds , 0 )
228+ # preds_masked_succ shape: (n_samples, n_labels, n_labels)
229+
224230 preds = preds_masked_succ .max (dim = 2 ).values
225231 if torch .sum (preds ) != preds_sum_orig :
226232 print (f"Preds change (step 1): { torch .sum (preds ) - preds_sum_orig } " )
You can’t perform that action at this time.
0 commit comments