Skip to content

Commit

Permalink
fix: Make col/row re-sorting optional on TF predictor (#19) [skip ci]
Browse files Browse the repository at this point in the history
* Made col/row re-sorting optional via a new predictor parameter in multi_table_predict

Signed-off-by: Maxim Lysak <[email protected]>

* Compute num_cols and num_rows directly from the prediction in the TF predict_details when sort_row_col_indexes is False

Signed-off-by: Maxim Lysak <[email protected]>

---------

Signed-off-by: Maxim Lysak <[email protected]>
Co-authored-by: Maxim Lysak <[email protected]>
Co-authored-by: Christoph Auer <[email protected]>
  • Loading branch information
3 people authored Sep 18, 2024
1 parent f60d3fc commit 7e9758c
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 52 deletions.
121 changes: 70 additions & 51 deletions docling_ibm_models/tableformer/data_management/tf_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,12 @@ def resize_img(self, image, width=None, height=None, inter=cv2.INTER_AREA):
return resized, sf

def multi_table_predict(
self, iocr_page, table_bboxes, do_matching=True, correct_overlapping_cells=False
self,
iocr_page,
table_bboxes,
do_matching=True,
correct_overlapping_cells=False,
sort_row_col_indexes=True,
):
multi_tf_output = []
page_image = iocr_page["image"]
Expand Down Expand Up @@ -563,56 +568,70 @@ def multi_table_predict(
# PROCESS PREDICTED RESULTS, TO TURN PREDICTED COL/ROW IDs into Indexes
# Indexes should be in increasing order, without gaps

# Fix col/row indexes
# Arranges all col/row indexes sequentially without gaps using input IDs

indexing_start_cols = [] # Index of original start col IDs (not indexes)
indexing_end_cols = [] # Index of original end col IDs (not indexes)
indexing_start_rows = [] # Index of original start row IDs (not indexes)
indexing_end_rows = [] # Index of original end row IDs (not indexes)

# First, collect all possible predicted IDs, to be used as indexes
# ID's returned by Tableformer are sequential, but might contain gaps
for tf_response_cell in tf_responses:
start_col_offset_idx = tf_response_cell["start_col_offset_idx"]
end_col_offset_idx = tf_response_cell["end_col_offset_idx"]
start_row_offset_idx = tf_response_cell["start_row_offset_idx"]
end_row_offset_idx = tf_response_cell["end_row_offset_idx"]

# Collect all possible col/row IDs:
if start_col_offset_idx not in indexing_start_cols:
indexing_start_cols.append(start_col_offset_idx)
if end_col_offset_idx not in indexing_end_cols:
indexing_end_cols.append(end_col_offset_idx)
if start_row_offset_idx not in indexing_start_rows:
indexing_start_rows.append(start_row_offset_idx)
if end_row_offset_idx not in indexing_end_rows:
indexing_end_rows.append(end_row_offset_idx)

indexing_start_cols.sort()
indexing_end_cols.sort()
indexing_start_rows.sort()
indexing_end_rows.sort()

# After this - put actual indexes of IDs back into predicted structure...
for tf_response_cell in tf_responses:
tf_response_cell["start_col_offset_idx"] = indexing_start_cols.index(
tf_response_cell["start_col_offset_idx"]
)
tf_response_cell["end_col_offset_idx"] = (
tf_response_cell["start_col_offset_idx"]
+ tf_response_cell["col_span"]
)
tf_response_cell["start_row_offset_idx"] = indexing_start_rows.index(
tf_response_cell["start_row_offset_idx"]
)
tf_response_cell["end_row_offset_idx"] = (
tf_response_cell["start_row_offset_idx"]
+ tf_response_cell["row_span"]
)
# Counting matched cols/rows from actual indexes (and not ids)
predict_details["num_cols"] = len(indexing_end_cols)
predict_details["num_rows"] = len(indexing_end_rows)
if sort_row_col_indexes:
# Fix col/row indexes
# Arranges all col/row indexes sequentially without gaps using input IDs

indexing_start_cols = (
[]
) # Index of original start col IDs (not indexes)
indexing_end_cols = [] # Index of original end col IDs (not indexes)
indexing_start_rows = (
[]
) # Index of original start row IDs (not indexes)
indexing_end_rows = [] # Index of original end row IDs (not indexes)

# First, collect all possible predicted IDs, to be used as indexes
# ID's returned by Tableformer are sequential, but might contain gaps
for tf_response_cell in tf_responses:
start_col_offset_idx = tf_response_cell["start_col_offset_idx"]
end_col_offset_idx = tf_response_cell["end_col_offset_idx"]
start_row_offset_idx = tf_response_cell["start_row_offset_idx"]
end_row_offset_idx = tf_response_cell["end_row_offset_idx"]

# Collect all possible col/row IDs:
if start_col_offset_idx not in indexing_start_cols:
indexing_start_cols.append(start_col_offset_idx)
if end_col_offset_idx not in indexing_end_cols:
indexing_end_cols.append(end_col_offset_idx)
if start_row_offset_idx not in indexing_start_rows:
indexing_start_rows.append(start_row_offset_idx)
if end_row_offset_idx not in indexing_end_rows:
indexing_end_rows.append(end_row_offset_idx)

indexing_start_cols.sort()
indexing_end_cols.sort()
indexing_start_rows.sort()
indexing_end_rows.sort()

# After this - put actual indexes of IDs back into predicted structure...
for tf_response_cell in tf_responses:
tf_response_cell["start_col_offset_idx"] = (
indexing_start_cols.index(
tf_response_cell["start_col_offset_idx"]
)
)
tf_response_cell["end_col_offset_idx"] = (
tf_response_cell["start_col_offset_idx"]
+ tf_response_cell["col_span"]
)
tf_response_cell["start_row_offset_idx"] = (
indexing_start_rows.index(
tf_response_cell["start_row_offset_idx"]
)
)
tf_response_cell["end_row_offset_idx"] = (
tf_response_cell["start_row_offset_idx"]
+ tf_response_cell["row_span"]
)
# Counting matched cols/rows from actual indexes (and not ids)
predict_details["num_cols"] = len(indexing_end_cols)
predict_details["num_rows"] = len(indexing_end_rows)
else:
otsl_seq = predict_details["prediction"]["rs_seq"]
predict_details["num_cols"] = otsl_seq.index("nl")
predict_details["num_rows"] = otsl_seq.count("nl")

# Put results into multi_tf_output
multi_tf_output.append(
{"tf_responses": tf_responses, "predict_details": predict_details}
Expand Down
6 changes: 5 additions & 1 deletion tests/test_tf_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,11 @@ def test_tf_predictor():
# List of dicts per table: [{"tf_responses":[...], "predict_details": {}}]

multi_tf_output = predictor.multi_table_predict(
iocr_page, table_bboxes, True
iocr_page,
table_bboxes,
do_matching=True,
correct_overlapping_cells=False,
sort_row_col_indexes=True
)

# Test output for validity, create visualizations...
Expand Down

0 comments on commit 7e9758c

Please sign in to comment.