Skip to content

Commit

Permalink
Merge pull request #915 from myhloli/dev
Browse files — browse the repository at this point in the history
feat(table): add RapidOCR support for RapidTable model
  • Loading branch information
myhloli authored Nov 8, 2024
2 parents 5e0c9d2 + fe2c2c0 commit 5a3872b
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 25 deletions.
59 changes: 34 additions & 25 deletions magic_pdf/model/pdf_extract_kit.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from unimernet.processors import load_processor
from doclayout_yolo import YOLOv10
from rapid_table import RapidTable
from rapidocr_paddle import RapidOCR

except ImportError as e:
logger.exception(e)
Expand All @@ -42,6 +43,7 @@


def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
ocr_engine = None
if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
table_model = StructTableModel(model_path, max_time=max_time)
elif table_model_type == MODEL_NAME.TABLE_MASTER:
Expand All @@ -52,11 +54,15 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
table_model = ppTableModel(config)
elif table_model_type == MODEL_NAME.RAPID_TABLE:
table_model = RapidTable()
ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
else:
logger.error("table model type not allow")
exit(1)

return table_model
if ocr_engine:
return [table_model, ocr_engine]
else:
return table_model


def mfd_model_init(weight):
Expand Down Expand Up @@ -283,23 +289,32 @@ def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name]))
)
# 初始化ocr
# if self.apply_ocr:
self.ocr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR,
ocr_show_log=show_log,
det_db_box_thresh=0.3,
lang=self.lang
)
if self.apply_ocr:
self.ocr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR,
ocr_show_log=show_log,
det_db_box_thresh=0.3,
lang=self.lang
)
# init table model
if self.apply_table:
table_model_dir = self.configs["weights"][self.table_model_name]
self.table_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Table,
table_model_name=self.table_model_name,
table_model_path=str(os.path.join(models_dir, table_model_dir)),
table_max_time=self.table_max_time,
device=self.device
)
if self.table_model_name in [MODEL_NAME.STRUCT_EQTABLE, MODEL_NAME.TABLE_MASTER]:
self.table_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Table,
table_model_name=self.table_model_name,
table_model_path=str(os.path.join(models_dir, table_model_dir)),
table_max_time=self.table_max_time,
device=self.device
)
elif self.table_model_name in [MODEL_NAME.RAPID_TABLE]:
self.table_model, self.ocr_engine =atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Table,
table_model_name=self.table_model_name,
table_model_path=str(os.path.join(models_dir, table_model_dir)),
table_max_time=self.table_max_time,
device=self.device
)

logger.info('DocAnalysis init done!')

Expand Down Expand Up @@ -381,9 +396,8 @@ def __call__(self, image):
table_res_list.append(res)

if torch.cuda.is_available() and self.device != 'cpu':
properties = torch.cuda.get_device_properties(self.device)
total_memory = properties.total_memory / (1024 ** 3) # 将字节转换为 GB
if total_memory <= 10:
total_memory = torch.cuda.get_device_properties(self.device).total_memory / (1024 ** 3) # 将字节转换为 GB
if total_memory <= 8:
gc_start = time.time()
clean_memory()
gc_time = round(time.time() - gc_start, 2)
Expand Down Expand Up @@ -456,13 +470,8 @@ def __call__(self, image):
elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
html_code = self.table_model.img2html(new_image)
elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
new_image_bgr = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
ocr_result = self.ocr_model.ocr(new_image_bgr)[0]
new_ocr_result = []
for box_ocr_res in ocr_result:
text, score = box_ocr_res[1]
new_ocr_result.append([box_ocr_res[0], text, score])
html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(new_image), new_ocr_result)
ocr_result, _ = self.ocr_engine(np.asarray(new_image))
html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(new_image), ocr_result)

run_time = time.time() - single_table_start_time
# logger.info(f"------------table recognition processing ends within {run_time}s-----")
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def parse_requirements(filename):
"einops", # struct-eqtable依赖
"accelerate", # struct-eqtable依赖
"doclayout_yolo==0.0.2", # doclayout_yolo
"rapidocr-paddle", # rapidocr-paddle
"rapid_table", # rapid_table
"detectron2"
],
Expand Down

0 comments on commit 5a3872b

Please sign in to comment.