Skip to content

Commit

Permalink
fix: safer bbox processing (#27)
Browse files Browse the repository at this point in the history
* Catch the case where a word bbox is already a list instead of a dict

Signed-off-by: Christoph Auer <[email protected]>

* Introducing safety checks around the processing of bbox and class tensors that led to crashes

Signed-off-by: Maxim Lysak <[email protected]>

---------

Signed-off-by: Christoph Auer <[email protected]>
Signed-off-by: Maxim Lysak <[email protected]>
Co-authored-by: Christoph Auer <[email protected]>
Co-authored-by: Maxim Lysak <[email protected]>
  • Loading branch information
3 people authored Sep 18, 2024
1 parent 7e9758c commit d37272e
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,15 @@ def match_cells(self, iocr_page, table_bbox, prediction):
pdf_cells = copy.deepcopy(iocr_page["tokens"])
if len(pdf_cells) > 0:
for word in pdf_cells:
word["bbox"] = [
word["bbox"]["l"],
word["bbox"]["t"],
word["bbox"]["r"],
word["bbox"]["b"],
]
if isinstance(word["bbox"], list):
continue
elif isinstance(word["bbox"], dict):
word["bbox"] = [
word["bbox"]["l"],
word["bbox"]["t"],
word["bbox"]["r"],
word["bbox"]["b"],
]
table_bboxes = prediction["bboxes"]
table_classes = prediction["classes"]
# BBOXES transformed...
Expand Down
30 changes: 22 additions & 8 deletions docling_ibm_models/tableformer/data_management/tf_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,13 +686,20 @@ def predict_dummy(
)

if outputs_coord is not None:
bbox_pred = u.box_cxcywh_to_xyxy(outputs_coord)
prediction["bboxes"] = bbox_pred.tolist()
if len(outputs_coord) == 0:
prediction["bboxes"] = []
else:
bbox_pred = u.box_cxcywh_to_xyxy(outputs_coord)
prediction["bboxes"] = bbox_pred.tolist()
else:
prediction["bboxes"] = []

if outputs_class is not None:
result_class = torch.argmax(outputs_class, dim=1)
prediction["classes"] = result_class.tolist()
if len(outputs_class) == 0:
prediction["classes"] = []
else:
result_class = torch.argmax(outputs_class, dim=1)
prediction["classes"] = result_class.tolist()
else:
prediction["classes"] = []
if self._remove_padding:
Expand Down Expand Up @@ -807,13 +814,20 @@ def predict(
)

if outputs_coord is not None:
bbox_pred = u.box_cxcywh_to_xyxy(outputs_coord)
prediction["bboxes"] = bbox_pred.tolist()
if len(outputs_coord) == 0:
prediction["bboxes"] = []
else:
bbox_pred = u.box_cxcywh_to_xyxy(outputs_coord)
prediction["bboxes"] = bbox_pred.tolist()
else:
prediction["bboxes"] = []

if outputs_class is not None:
result_class = torch.argmax(outputs_class, dim=1)
prediction["classes"] = result_class.tolist()
if len(outputs_class) == 0:
prediction["classes"] = []
else:
result_class = torch.argmax(outputs_class, dim=1)
prediction["classes"] = result_class.tolist()
else:
prediction["classes"] = []
if self._remove_padding:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,12 @@ def predict(self, imgs, max_steps, k, return_attention=False):

if len(outputs_coord1) > 0:
outputs_coord1 = torch.stack(outputs_coord1)
else:
outputs_coord1 = torch.empty(0)
if len(outputs_class1) > 0:
outputs_class1 = torch.stack(outputs_class1)
else:
outputs_class1 = torch.empty(0)

outputs_class = outputs_class1
outputs_coord = outputs_coord1
Expand Down

0 comments on commit d37272e

Please sign in to comment.