Skip to content

Commit

Permalink
Merge pull request #1556 from myhloli/dev
Browse files Browse the repository at this point in the history
feat(table): upgrade RapidTable to 1.0.3 and add sub-model support
  • Loading branch information
myhloli authored Jan 16, 2025
2 parents 63c267f + 452a9c0 commit 230191c
Show file tree
Hide file tree
Showing 9 changed files with 37 additions and 16 deletions.
2 changes: 1 addition & 1 deletion docker/ascend_npu/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ einops
accelerate
rapidocr-paddle
rapidocr-onnxruntime
rapid_table==0.3.0
rapid-table>=1.0.3,<2.0.0
doclayout-yolo==0.0.2b1
openai
detectron2
2 changes: 1 addition & 1 deletion docker/china/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ einops
accelerate
rapidocr-paddle
rapidocr-onnxruntime
rapid_table==0.3.0
rapid-table>=1.0.3,<2.0.0
doclayout-yolo==0.0.2b1
openai
detectron2
2 changes: 1 addition & 1 deletion docker/global/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ einops
accelerate
rapidocr-paddle
rapidocr-onnxruntime
rapid_table==0.3.0
rapid-table>=1.0.3,<2.0.0
doclayout-yolo==0.0.2b1
openai
detectron2
3 changes: 2 additions & 1 deletion magic-pdf.template.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
},
"table-config": {
"model": "rapid_table",
"sub_model": "slanet_plus",
"enable": true,
"max_time": 400
},
Expand All @@ -39,5 +40,5 @@
"enable": false
}
},
"config_version": "1.1.0"
"config_version": "1.1.1"
}
2 changes: 1 addition & 1 deletion magic_pdf/model/batch_analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def __call__(self, images: list) -> list:
elif self.model.table_model_name == MODEL_NAME.TABLE_MASTER:
html_code = self.model.table_model.img2html(new_image)
elif self.model.table_model_name == MODEL_NAME.RAPID_TABLE:
html_code, table_cell_bboxes, elapse = (
html_code, table_cell_bboxes, logic_points, elapse = (
self.model.table_model.predict(new_image)
)
run_time = time.time() - single_table_start_time
Expand Down
4 changes: 3 additions & 1 deletion magic_pdf/model/pdf_extract_kit.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
self.apply_table = self.table_config.get('enable', False)
self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)
self.table_sub_model_name = self.table_config.get('sub_model', None)

# ocr config
self.apply_ocr = ocr
Expand Down Expand Up @@ -174,6 +175,7 @@ def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
table_max_time=self.table_max_time,
device=self.device,
ocr_engine=self.ocr_model,
table_sub_model_name=self.table_sub_model_name
)

logger.info('DocAnalysis init done!')
Expand Down Expand Up @@ -276,7 +278,7 @@ def __call__(self, image):
elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
html_code = self.table_model.img2html(new_image)
elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
html_code, table_cell_bboxes, elapse = self.table_model.predict(
html_code, table_cell_bboxes, logic_points, elapse = self.table_model.predict(
new_image
)
run_time = time.time() - single_table_start_time
Expand Down
7 changes: 4 additions & 3 deletions magic_pdf/model/sub_modules/model_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
TableMasterPaddleModel


def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None):
def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
elif table_model_type == MODEL_NAME.TABLE_MASTER:
Expand All @@ -31,7 +31,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr
}
table_model = TableMasterPaddleModel(config)
elif table_model_type == MODEL_NAME.RAPID_TABLE:
table_model = RapidTableModel(ocr_engine)
table_model = RapidTableModel(ocr_engine, table_sub_model_name)
else:
logger.error('table model type not allow')
exit(1)
Expand Down Expand Up @@ -163,7 +163,8 @@ def atom_model_init(model_name: str, **kwargs):
kwargs.get('table_model_path'),
kwargs.get('table_max_time'),
kwargs.get('device'),
kwargs.get('ocr_engine')
kwargs.get('ocr_engine'),
kwargs.get('table_sub_model_name')
)
elif model_name == AtomicModel.LangDetect:
if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect:
Expand Down
29 changes: 23 additions & 6 deletions magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,25 @@
import numpy as np
import torch
from loguru import logger
from rapid_table import RapidTable
from rapid_table import RapidTable, RapidTableInput
from rapid_table.main import ModelType


class RapidTableModel(object):
def __init__(self, ocr_engine):
self.table_model = RapidTable()
def __init__(self, ocr_engine, table_sub_model_name):
sub_model_list = [model.value for model in ModelType]
if table_sub_model_name is None:
input_args = RapidTableInput()
elif table_sub_model_name in sub_model_list:
if torch.cuda.is_available() and table_sub_model_name == "unitable":
input_args = RapidTableInput(model_type=table_sub_model_name, use_cuda=True)
else:
input_args = RapidTableInput(model_type=table_sub_model_name)
else:
raise ValueError(f"Invalid table_sub_model_name: {table_sub_model_name}. It must be one of {sub_model_list}")

self.table_model = RapidTable(input_args)

# if ocr_engine is None:
# self.ocr_model_name = "RapidOCR"
# if torch.cuda.is_available():
Expand Down Expand Up @@ -45,7 +58,11 @@ def predict(self, image):
ocr_result = None

if ocr_result:
html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
return html_code, table_cell_bboxes, elapse
table_results = self.table_model(np.asarray(image), ocr_result)
html_code = table_results.pred_html
table_cell_bboxes = table_results.cell_bboxes
logic_points = table_results.logic_points
elapse = table_results.elapse
return html_code, table_cell_bboxes, logic_points, elapse
else:
return None, None, None
return None, None, None, None
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def parse_requirements(filename):
"doclayout_yolo==0.0.2b1", # doclayout_yolo
"rapidocr-paddle", # rapidocr-paddle
"rapidocr_onnxruntime",
"rapid_table==0.3.0", # rapid_table
"rapid_table>=1.0.3,<2.0.0", # rapid_table
"PyYAML", # yaml
"openai", # openai SDK
"detectron2"
Expand Down

0 comments on commit 230191c

Please sign in to comment.