diff --git a/magic-pdf.template.json b/magic-pdf.template.json index 114dfce3..cdb3dab6 100644 --- a/magic-pdf.template.json +++ b/magic-pdf.template.json @@ -15,7 +15,7 @@ "enable": true }, "table-config": { - "model": "tablemaster", + "model": "rapid_table", "enable": false, "max_time": 400 }, diff --git a/magic_pdf/libs/Constants.py b/magic_pdf/libs/Constants.py index 0799f6fd..188465e8 100644 --- a/magic_pdf/libs/Constants.py +++ b/magic_pdf/libs/Constants.py @@ -50,4 +50,6 @@ class MODEL_NAME: YOLO_V8_MFD = "yolo_v8_mfd" - UniMerNet_v2_Small = "unimernet_small" \ No newline at end of file + UniMerNet_v2_Small = "unimernet_small" + + RAPID_TABLE = "rapid_table" \ No newline at end of file diff --git a/magic_pdf/libs/config_reader.py b/magic_pdf/libs/config_reader.py index 5e1a300d..b1126b64 100644 --- a/magic_pdf/libs/config_reader.py +++ b/magic_pdf/libs/config_reader.py @@ -92,7 +92,7 @@ def get_table_recog_config(): table_config = config.get('table-config') if table_config is None: logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default") - return json.loads(f'{{"model": "{MODEL_NAME.TABLE_MASTER}","enable": false, "max_time": 400}}') + return json.loads(f'{{"model": "{MODEL_NAME.RAPID_TABLE}","enable": false, "max_time": 400}}') else: return table_config diff --git a/magic_pdf/model/pdf_extract_kit.py b/magic_pdf/model/pdf_extract_kit.py index 21a07f5b..48a636f9 100644 --- a/magic_pdf/model/pdf_extract_kit.py +++ b/magic_pdf/model/pdf_extract_kit.py @@ -1,8 +1,6 @@ from loguru import logger import os import time -from pathlib import Path -import shutil from magic_pdf.libs.Constants import * from magic_pdf.libs.clean_memory import clean_memory from magic_pdf.model.model_list import AtomicModel @@ -27,6 +25,7 @@ import unimernet.tasks as tasks from unimernet.processors import load_processor from doclayout_yolo import YOLOv10 + from rapid_table import RapidTable except ImportError as e: logger.exception(e) @@ -51,9 +50,12 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'): "device": _device_ } table_model = ppTableModel(config) + elif table_model_type == MODEL_NAME.RAPID_TABLE: + table_model = RapidTable() else: logger.error("table model type not allow") exit(1) + return table_model @@ -226,7 +228,7 @@ def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs): self.table_config = kwargs.get("table_config") self.apply_table = self.table_config.get("enable", False) self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE) - self.table_model_name = self.table_config.get("model", MODEL_NAME.TABLE_MASTER) + self.table_model_name = self.table_config.get("model", MODEL_NAME.RAPID_TABLE) # ocr config self.apply_ocr = ocr @@ -281,13 +283,13 @@ def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs): doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])) ) # 初始化ocr - if self.apply_ocr: - self.ocr_model = atom_model_manager.get_atom_model( - atom_model_name=AtomicModel.OCR, - ocr_show_log=show_log, - det_db_box_thresh=0.3, - lang=self.lang - ) + # if self.apply_ocr: + self.ocr_model = atom_model_manager.get_atom_model( + atom_model_name=AtomicModel.OCR, + ocr_show_log=show_log, + det_db_box_thresh=0.3, + lang=self.lang + ) # init table model if self.apply_table: table_model_dir = self.configs["weights"][self.table_model_name] @@ -451,8 +453,16 @@ def __call__(self, image): table_result = self.table_model.predict(new_image, "html") if len(table_result) > 0: html_code = table_result[0] - else: + elif self.table_model_name == MODEL_NAME.TABLE_MASTER: html_code = self.table_model.img2html(new_image) + elif self.table_model_name == MODEL_NAME.RAPID_TABLE: + new_image_bgr = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR) + ocr_result = self.ocr_model.ocr(new_image_bgr)[0] + new_ocr_result = [] + for box_ocr_res in ocr_result: + text, score = box_ocr_res[1] + new_ocr_result.append([box_ocr_res[0], text, score]) + html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(new_image), new_ocr_result) run_time = time.time() - single_table_start_time # logger.info(f"------------table recognition processing ends within {run_time}s-----") diff --git a/magic_pdf/resources/model_config/model_configs.yaml b/magic_pdf/resources/model_config/model_configs.yaml index e56d6ee1..a11f509f 100644 --- a/magic_pdf/resources/model_config/model_configs.yaml +++ b/magic_pdf/resources/model_config/model_configs.yaml @@ -4,4 +4,5 @@ weights: yolo_v8_mfd: MFD/YOLO/yolo_v8_ft.pt unimernet_small: MFR/unimernet_small struct_eqtable: TabRec/StructEqTable - tablemaster: TabRec/TableMaster \ No newline at end of file + tablemaster: TabRec/TableMaster + rapid_table: TabRec/RapidTable \ No newline at end of file diff --git a/setup.py b/setup.py index 513e349b..416ea309 100644 --- a/setup.py +++ b/setup.py @@ -47,6 +47,7 @@ def parse_requirements(filename): "einops", # struct-eqtable依赖 "accelerate", # struct-eqtable依赖 "doclayout_yolo==0.0.2", # doclayout_yolo + "rapid_table", # rapid_table "detectron2" ], },