From d37272eb8e441045dc7c903530db4d1afd051f3e Mon Sep 17 00:00:00 2001 From: Maxim Lysak <101627549+maxmnemonic@users.noreply.github.com> Date: Wed, 18 Sep 2024 16:19:34 +0200 Subject: [PATCH] fix: safer bbox processing (#27) * Catch word bbox format case Signed-off-by: Christoph Auer * Introducing safety around bboxes and class tensors processing that led to crashes Signed-off-by: Maxim Lysak --------- Signed-off-by: Christoph Auer Signed-off-by: Maxim Lysak Co-authored-by: Christoph Auer Co-authored-by: Maxim Lysak --- .../data_management/tf_cell_matcher.py | 15 ++++++---- .../data_management/tf_predictor.py | 30 ++++++++++++++----- .../models/table04_rs/tablemodel04_rs.py | 4 +++ 3 files changed, 35 insertions(+), 14 deletions(-) diff --git a/docling_ibm_models/tableformer/data_management/tf_cell_matcher.py b/docling_ibm_models/tableformer/data_management/tf_cell_matcher.py index 9025789..2b7aca2 100644 --- a/docling_ibm_models/tableformer/data_management/tf_cell_matcher.py +++ b/docling_ibm_models/tableformer/data_management/tf_cell_matcher.py @@ -129,12 +129,15 @@ def match_cells(self, iocr_page, table_bbox, prediction): pdf_cells = copy.deepcopy(iocr_page["tokens"]) if len(pdf_cells) > 0: for word in pdf_cells: - word["bbox"] = [ - word["bbox"]["l"], - word["bbox"]["t"], - word["bbox"]["r"], - word["bbox"]["b"], - ] + if isinstance(word["bbox"], list): + continue + elif isinstance(word["bbox"], dict): + word["bbox"] = [ + word["bbox"]["l"], + word["bbox"]["t"], + word["bbox"]["r"], + word["bbox"]["b"], + ] table_bboxes = prediction["bboxes"] table_classes = prediction["classes"] # BBOXES transformed... 
diff --git a/docling_ibm_models/tableformer/data_management/tf_predictor.py b/docling_ibm_models/tableformer/data_management/tf_predictor.py index 372d72b..74320d5 100644 --- a/docling_ibm_models/tableformer/data_management/tf_predictor.py +++ b/docling_ibm_models/tableformer/data_management/tf_predictor.py @@ -686,13 +686,20 @@ def predict_dummy( ) if outputs_coord is not None: - bbox_pred = u.box_cxcywh_to_xyxy(outputs_coord) - prediction["bboxes"] = bbox_pred.tolist() + if len(outputs_coord) == 0: + prediction["bboxes"] = [] + else: + bbox_pred = u.box_cxcywh_to_xyxy(outputs_coord) + prediction["bboxes"] = bbox_pred.tolist() else: prediction["bboxes"] = [] + if outputs_class is not None: - result_class = torch.argmax(outputs_class, dim=1) - prediction["classes"] = result_class.tolist() + if len(outputs_class) == 0: + prediction["classes"] = [] + else: + result_class = torch.argmax(outputs_class, dim=1) + prediction["classes"] = result_class.tolist() else: prediction["classes"] = [] if self._remove_padding: @@ -807,13 +814,20 @@ def predict( ) if outputs_coord is not None: - bbox_pred = u.box_cxcywh_to_xyxy(outputs_coord) - prediction["bboxes"] = bbox_pred.tolist() + if len(outputs_coord) == 0: + prediction["bboxes"] = [] + else: + bbox_pred = u.box_cxcywh_to_xyxy(outputs_coord) + prediction["bboxes"] = bbox_pred.tolist() else: prediction["bboxes"] = [] + if outputs_class is not None: - result_class = torch.argmax(outputs_class, dim=1) - prediction["classes"] = result_class.tolist() + if len(outputs_class) == 0: + prediction["classes"] = [] + else: + result_class = torch.argmax(outputs_class, dim=1) + prediction["classes"] = result_class.tolist() else: prediction["classes"] = [] if self._remove_padding: diff --git a/docling_ibm_models/tableformer/models/table04_rs/tablemodel04_rs.py b/docling_ibm_models/tableformer/models/table04_rs/tablemodel04_rs.py index acfacbe..2e047d7 100644 --- a/docling_ibm_models/tableformer/models/table04_rs/tablemodel04_rs.py +++ 
b/docling_ibm_models/tableformer/models/table04_rs/tablemodel04_rs.py @@ -308,8 +308,12 @@ def predict(self, imgs, max_steps, k, return_attention=False): if len(outputs_coord1) > 0: outputs_coord1 = torch.stack(outputs_coord1) + else: + outputs_coord1 = torch.empty(0) if len(outputs_class1) > 0: outputs_class1 = torch.stack(outputs_class1) + else: + outputs_class1 = torch.empty(0) outputs_class = outputs_class1 outputs_coord = outputs_coord1