diff --git a/docling_ibm_models/tableformer/data_management/tf_predictor.py b/docling_ibm_models/tableformer/data_management/tf_predictor.py index 372d72b..74320d5 100644 --- a/docling_ibm_models/tableformer/data_management/tf_predictor.py +++ b/docling_ibm_models/tableformer/data_management/tf_predictor.py @@ -686,13 +686,20 @@ def predict_dummy( ) if outputs_coord is not None: - bbox_pred = u.box_cxcywh_to_xyxy(outputs_coord) - prediction["bboxes"] = bbox_pred.tolist() + if len(outputs_coord) == 0: + prediction["bboxes"] = [] + else: + bbox_pred = u.box_cxcywh_to_xyxy(outputs_coord) + prediction["bboxes"] = bbox_pred.tolist() else: prediction["bboxes"] = [] + if outputs_class is not None: - result_class = torch.argmax(outputs_class, dim=1) - prediction["classes"] = result_class.tolist() + if len(outputs_class) == 0: + prediction["classes"] = [] + else: + result_class = torch.argmax(outputs_class, dim=1) + prediction["classes"] = result_class.tolist() else: prediction["classes"] = [] if self._remove_padding: @@ -807,13 +814,20 @@ def predict( ) if outputs_coord is not None: - bbox_pred = u.box_cxcywh_to_xyxy(outputs_coord) - prediction["bboxes"] = bbox_pred.tolist() + if len(outputs_coord) == 0: + prediction["bboxes"] = [] + else: + bbox_pred = u.box_cxcywh_to_xyxy(outputs_coord) + prediction["bboxes"] = bbox_pred.tolist() else: prediction["bboxes"] = [] + if outputs_class is not None: - result_class = torch.argmax(outputs_class, dim=1) - prediction["classes"] = result_class.tolist() + if len(outputs_class) == 0: + prediction["classes"] = [] + else: + result_class = torch.argmax(outputs_class, dim=1) + prediction["classes"] = result_class.tolist() else: prediction["classes"] = [] if self._remove_padding: diff --git a/docling_ibm_models/tableformer/models/table04_rs/tablemodel04_rs.py b/docling_ibm_models/tableformer/models/table04_rs/tablemodel04_rs.py index acfacbe..2e047d7 100644 --- a/docling_ibm_models/tableformer/models/table04_rs/tablemodel04_rs.py +++ b/docling_ibm_models/tableformer/models/table04_rs/tablemodel04_rs.py @@ -308,8 +308,12 @@ def predict(self, imgs, max_steps, k, return_attention=False): if len(outputs_coord1) > 0: outputs_coord1 = torch.stack(outputs_coord1) + else: + outputs_coord1 = torch.empty(0) if len(outputs_class1) > 0: outputs_class1 = torch.stack(outputs_class1) + else: + outputs_class1 = torch.empty(0) outputs_class = outputs_class1 outputs_coord = outputs_coord1