diff --git a/docling_ibm_models/tableformer/data_management/tf_predictor.py b/docling_ibm_models/tableformer/data_management/tf_predictor.py index 48e28ac..372d72b 100644 --- a/docling_ibm_models/tableformer/data_management/tf_predictor.py +++ b/docling_ibm_models/tableformer/data_management/tf_predictor.py @@ -524,7 +524,12 @@ def resize_img(self, image, width=None, height=None, inter=cv2.INTER_AREA): return resized, sf def multi_table_predict( - self, iocr_page, table_bboxes, do_matching=True, correct_overlapping_cells=False + self, + iocr_page, + table_bboxes, + do_matching=True, + correct_overlapping_cells=False, + sort_row_col_indexes=True, ): multi_tf_output = [] page_image = iocr_page["image"] @@ -563,56 +568,70 @@ def multi_table_predict( # PROCESS PREDICTED RESULTS, TO TURN PREDICTED COL/ROW IDs into Indexes # Indexes should be in increasing order, without gaps - # Fix col/row indexes - # Arranges all col/row indexes sequentially without gaps using input IDs - - indexing_start_cols = [] # Index of original start col IDs (not indexes) - indexing_end_cols = [] # Index of original end col IDs (not indexes) - indexing_start_rows = [] # Index of original start row IDs (not indexes) - indexing_end_rows = [] # Index of original end row IDs (not indexes) - - # First, collect all possible predicted IDs, to be used as indexes - # ID's returned by Tableformer are sequential, but might contain gaps - for tf_response_cell in tf_responses: - start_col_offset_idx = tf_response_cell["start_col_offset_idx"] - end_col_offset_idx = tf_response_cell["end_col_offset_idx"] - start_row_offset_idx = tf_response_cell["start_row_offset_idx"] - end_row_offset_idx = tf_response_cell["end_row_offset_idx"] - - # Collect all possible col/row IDs: - if start_col_offset_idx not in indexing_start_cols: - indexing_start_cols.append(start_col_offset_idx) - if end_col_offset_idx not in indexing_end_cols: - indexing_end_cols.append(end_col_offset_idx) - if start_row_offset_idx not in indexing_start_rows: - indexing_start_rows.append(start_row_offset_idx) - if end_row_offset_idx not in indexing_end_rows: - indexing_end_rows.append(end_row_offset_idx) - - indexing_start_cols.sort() - indexing_end_cols.sort() - indexing_start_rows.sort() - indexing_end_rows.sort() - - # After this - put actual indexes of IDs back into predicted structure... - for tf_response_cell in tf_responses: - tf_response_cell["start_col_offset_idx"] = indexing_start_cols.index( - tf_response_cell["start_col_offset_idx"] - ) - tf_response_cell["end_col_offset_idx"] = ( - tf_response_cell["start_col_offset_idx"] - + tf_response_cell["col_span"] - ) - tf_response_cell["start_row_offset_idx"] = indexing_start_rows.index( - tf_response_cell["start_row_offset_idx"] - ) - tf_response_cell["end_row_offset_idx"] = ( - tf_response_cell["start_row_offset_idx"] - + tf_response_cell["row_span"] - ) - # Counting matched cols/rows from actual indexes (and not ids) - predict_details["num_cols"] = len(indexing_end_cols) - predict_details["num_rows"] = len(indexing_end_rows) + if sort_row_col_indexes: + # Fix col/row indexes + # Arranges all col/row indexes sequentially without gaps using input IDs + + indexing_start_cols = ( + [] + ) # Index of original start col IDs (not indexes) + indexing_end_cols = [] # Index of original end col IDs (not indexes) + indexing_start_rows = ( + [] + ) # Index of original start row IDs (not indexes) + indexing_end_rows = [] # Index of original end row IDs (not indexes) + + # First, collect all possible predicted IDs, to be used as indexes + # ID's returned by Tableformer are sequential, but might contain gaps + for tf_response_cell in tf_responses: + start_col_offset_idx = tf_response_cell["start_col_offset_idx"] + end_col_offset_idx = tf_response_cell["end_col_offset_idx"] + start_row_offset_idx = tf_response_cell["start_row_offset_idx"] + end_row_offset_idx = tf_response_cell["end_row_offset_idx"] + + # Collect all possible col/row IDs: + if start_col_offset_idx not in indexing_start_cols: + indexing_start_cols.append(start_col_offset_idx) + if end_col_offset_idx not in indexing_end_cols: + indexing_end_cols.append(end_col_offset_idx) + if start_row_offset_idx not in indexing_start_rows: + indexing_start_rows.append(start_row_offset_idx) + if end_row_offset_idx not in indexing_end_rows: + indexing_end_rows.append(end_row_offset_idx) + + indexing_start_cols.sort() + indexing_end_cols.sort() + indexing_start_rows.sort() + indexing_end_rows.sort() + + # After this - put actual indexes of IDs back into predicted structure... + for tf_response_cell in tf_responses: + tf_response_cell["start_col_offset_idx"] = ( + indexing_start_cols.index( + tf_response_cell["start_col_offset_idx"] + ) + ) + tf_response_cell["end_col_offset_idx"] = ( + tf_response_cell["start_col_offset_idx"] + + tf_response_cell["col_span"] + ) + tf_response_cell["start_row_offset_idx"] = ( + indexing_start_rows.index( + tf_response_cell["start_row_offset_idx"] + ) + ) + tf_response_cell["end_row_offset_idx"] = ( + tf_response_cell["start_row_offset_idx"] + + tf_response_cell["row_span"] + ) + # Counting matched cols/rows from actual indexes (and not ids) + predict_details["num_cols"] = len(indexing_end_cols) + predict_details["num_rows"] = len(indexing_end_rows) + else: + otsl_seq = predict_details["prediction"]["rs_seq"] + predict_details["num_cols"] = otsl_seq.index("nl") + predict_details["num_rows"] = otsl_seq.count("nl") + # Put results into multi_tf_output multi_tf_output.append( {"tf_responses": tf_responses, "predict_details": predict_details} diff --git a/tests/test_tf_predictor.py b/tests/test_tf_predictor.py index 9c355c4..f303c8f 100644 --- a/tests/test_tf_predictor.py +++ b/tests/test_tf_predictor.py @@ -512,7 +512,11 @@ def test_tf_predictor(): # List of dicts per table: [{"tf_responses":[...], "predict_details": {}}] multi_tf_output = predictor.multi_table_predict( - iocr_page, table_bboxes, True + iocr_page, + table_bboxes, + do_matching=True, + correct_overlapping_cells=False, + sort_row_col_indexes=True ) # Test output for validity, create visualizations...