Skip to content

Commit

Permalink
Introducting safety around bboxes and class tensors processing that l…
Browse files Browse the repository at this point in the history
…ead to crashes

Signed-off-by: Maxim Lysak <[email protected]>
  • Loading branch information
Maxim Lysak committed Sep 18, 2024
1 parent 7e9758c commit f4710b6
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 8 deletions.
30 changes: 22 additions & 8 deletions docling_ibm_models/tableformer/data_management/tf_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,13 +686,20 @@ def predict_dummy(
)

if outputs_coord is not None:
bbox_pred = u.box_cxcywh_to_xyxy(outputs_coord)
prediction["bboxes"] = bbox_pred.tolist()
if len(outputs_coord) == 0:
prediction["bboxes"] = []
else:
bbox_pred = u.box_cxcywh_to_xyxy(outputs_coord)
prediction["bboxes"] = bbox_pred.tolist()
else:
prediction["bboxes"] = []

if outputs_class is not None:
result_class = torch.argmax(outputs_class, dim=1)
prediction["classes"] = result_class.tolist()
if len(outputs_class) == 0:
prediction["classes"] = []
else:
result_class = torch.argmax(outputs_class, dim=1)
prediction["classes"] = result_class.tolist()
else:
prediction["classes"] = []
if self._remove_padding:
Expand Down Expand Up @@ -807,13 +814,20 @@ def predict(
)

if outputs_coord is not None:
bbox_pred = u.box_cxcywh_to_xyxy(outputs_coord)
prediction["bboxes"] = bbox_pred.tolist()
if len(outputs_coord) == 0:
prediction["bboxes"] = []
else:
bbox_pred = u.box_cxcywh_to_xyxy(outputs_coord)
prediction["bboxes"] = bbox_pred.tolist()
else:
prediction["bboxes"] = []

if outputs_class is not None:
result_class = torch.argmax(outputs_class, dim=1)
prediction["classes"] = result_class.tolist()
if len(outputs_class) == 0:
prediction["classes"] = []
else:
result_class = torch.argmax(outputs_class, dim=1)
prediction["classes"] = result_class.tolist()
else:
prediction["classes"] = []
if self._remove_padding:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,12 @@ def predict(self, imgs, max_steps, k, return_attention=False):

if len(outputs_coord1) > 0:
outputs_coord1 = torch.stack(outputs_coord1)
else:
outputs_coord1 = torch.empty(0)
if len(outputs_class1) > 0:
outputs_class1 = torch.stack(outputs_class1)
else:
outputs_class1 = torch.empty(0)

outputs_class = outputs_class1
outputs_coord = outputs_coord1
Expand Down

0 comments on commit f4710b6

Please sign in to comment.