Skip to content

Commit

Permalink
Merge pull request #1453 from myhloli/dev
Browse files Browse the repository at this point in the history
refactor(langdetect): simplify language detection model
  • Loading branch information
myhloli authored Jan 9, 2025
2 parents c634e2d + 3271cf7 commit aa53531
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 24 deletions.
3 changes: 2 additions & 1 deletion docs/README_Ascend_NPU_Acceleration_zh_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ magic-pdf --help

## 已知问题

- paddleocr使用内嵌onnx模型,仅支持中英文ocr,不支持其他语言ocr
- paddleocr使用内嵌onnx模型,仅在默认语言配置下能以较快速度对中英文进行识别
- 自定义lang参数时,paddleocr速度会存在明显下降情况
- layout模型使用layoutlmv3时会发生间歇性崩溃,建议使用默认配置的doclayout_yolo模型
- 表格解析仅适配了rapid_table模型,其他模型可能会无法使用
1 change: 1 addition & 0 deletions magic_pdf/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def __init__(self, bits: bytes, lang=None):
logger.info(f"lang: {lang}, detect_lang: {self._lang}")
else:
self._lang = lang
logger.info(f"lang: {lang}")
def __len__(self) -> int:
"""The page number of the pdf."""
return len(self._records)
Expand Down
1 change: 1 addition & 0 deletions magic_pdf/model/model_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ class AtomicModel:
MFR = "mfr"
OCR = "ocr"
Table = "table"
LangDetect = "langdetect"
35 changes: 24 additions & 11 deletions magic_pdf/model/sub_modules/language_detection/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from magic_pdf.libs.config_reader import get_local_models_dir, get_device
from magic_pdf.libs.pdf_check import extract_pages
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton


Expand Down Expand Up @@ -59,15 +58,29 @@ def get_text_images(simple_images):
def auto_detect_lang(pdf_bytes: bytes):
sample_docs = extract_pages(pdf_bytes)
sample_pdf_bytes = sample_docs.tobytes()
simple_images = load_images_from_pdf(sample_pdf_bytes, dpi=96)
simple_images = load_images_from_pdf(sample_pdf_bytes, dpi=200)
text_images = get_text_images(simple_images)
local_models_dir, device, configs = get_model_config()
# 用yolo11做语言分类
langdetect_model_weights = str(
os.path.join(
local_models_dir, configs['weights'][MODEL_NAME.YOLO_V11_LangDetect]
)
)
langdetect_model = YOLOv11LangDetModel(langdetect_model_weights, device)
langdetect_model = model_init(MODEL_NAME.YOLO_V11_LangDetect)
lang = langdetect_model.do_detect(text_images)
return lang
return lang


def model_init(model_name: str):
atom_model_manager = AtomModelSingleton()

if model_name == MODEL_NAME.YOLO_V11_LangDetect:
local_models_dir, device, configs = get_model_config()
model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.LangDetect,
langdetect_model_name=MODEL_NAME.YOLO_V11_LangDetect,
langdetect_model_weight=str(
os.path.join(
local_models_dir, configs['weights'][MODEL_NAME.YOLO_V11_LangDetect]
)
),
device=device,
)
else:
raise ValueError(f"model_name {model_name} not found")
return model

15 changes: 10 additions & 5 deletions magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import Counter
from uuid import uuid4

import torch
from PIL import Image
from loguru import logger
from ultralytics import YOLO
Expand Down Expand Up @@ -83,10 +84,14 @@ def resize_images_to_224(image):


class YOLOv11LangDetModel(object):
def __init__(self, weight, device):
self.model = YOLO(weight)
self.device = device
def __init__(self, langdetect_model_weight, device):

self.model = YOLO(langdetect_model_weight)

if str(device).startswith("npu"):
self.device = torch.device(device)
else:
self.device = device
def do_detect(self, images: list):
all_images = []
for image in images:
Expand All @@ -99,15 +104,14 @@ def do_detect(self, images: list):
all_images.append(resize_images_to_224(temp_image))

images_lang_res = self.batch_predict(all_images, batch_size=8)
logger.info(f"images_lang_res: {images_lang_res}")
# logger.info(f"images_lang_res: {images_lang_res}")
if len(images_lang_res) > 0:
count_dict = Counter(images_lang_res)
language = max(count_dict, key=count_dict.get)
else:
language = None
return language


def predict(self, image):
results = self.model.predict(image, verbose=False, device=self.device)
predicted_class_id = int(results[0].probs.top1)
Expand All @@ -117,6 +121,7 @@ def predict(self, image):

def batch_predict(self, images: list, batch_size: int) -> list:
images_lang_res = []

for index in range(0, len(images), batch_size):
lang_res = [
image_res.cpu()
Expand Down
21 changes: 20 additions & 1 deletion magic_pdf/model/sub_modules/model_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from loguru import logger

from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.libs.config_reader import get_device
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
DocLayoutYOLOModel
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
Expand Down Expand Up @@ -63,6 +63,13 @@ def doclayout_yolo_model_init(weight, device='cpu'):
return model


def langdetect_model_init(langdetect_model_weight, device='cpu'):
if str(device).startswith("npu"):
device = torch.device(device)
model = YOLOv11LangDetModel(langdetect_model_weight, device)
return model


def ocr_model_init(show_log: bool = False,
det_db_box_thresh=0.3,
lang=None,
Expand Down Expand Up @@ -130,6 +137,9 @@ def atom_model_init(model_name: str, **kwargs):
kwargs.get('doclayout_yolo_weights'),
kwargs.get('device')
)
else:
logger.error('layout model name not allow')
exit(1)
elif model_name == AtomicModel.MFD:
atom_model = mfd_model_init(
kwargs.get('mfd_weights'),
Expand All @@ -155,6 +165,15 @@ def atom_model_init(model_name: str, **kwargs):
kwargs.get('device'),
kwargs.get('ocr_engine')
)
elif model_name == AtomicModel.LangDetect:
if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect:
atom_model = langdetect_model_init(
kwargs.get('langdetect_model_weight'),
kwargs.get('device')
)
else:
logger.error('langdetect model name not allow')
exit(1)
else:
logger.error('model name not allow')
exit(1)
Expand Down
10 changes: 5 additions & 5 deletions magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class ModifiedPaddleOCR(PaddleOCR):
def __init__(self, *args, **kwargs):

super().__init__(*args, **kwargs)

self.lang = kwargs.get('lang', 'ch')
# 在cpu架构为arm且不支持cuda时调用onnx、
if not torch.cuda.is_available() and platform.machine() in ['arm64', 'aarch64']:
self.use_onnx = True
Expand Down Expand Up @@ -94,7 +94,7 @@ def preprocess_image(_image):
ocr_res = []
for img in imgs:
img = preprocess_image(img)
if self.use_onnx:
if self.lang in ['ch'] and self.use_onnx:
dt_boxes, elapse = self.additional_ocr.text_detector(img)
else:
dt_boxes, elapse = self.text_detector(img)
Expand Down Expand Up @@ -124,7 +124,7 @@ def preprocess_image(_image):
img, cls_res_tmp, elapse = self.text_classifier(img)
if not rec:
cls_res.append(cls_res_tmp)
if self.use_onnx:
if self.lang in ['ch'] and self.use_onnx:
rec_res, elapse = self.additional_ocr.text_recognizer(img)
else:
rec_res, elapse = self.text_recognizer(img)
Expand All @@ -142,7 +142,7 @@ def __call__(self, img, cls=True, mfd_res=None):

start = time.time()
ori_im = img.copy()
if self.use_onnx:
if self.lang in ['ch'] and self.use_onnx:
dt_boxes, elapse = self.additional_ocr.text_detector(img)
else:
dt_boxes, elapse = self.text_detector(img)
Expand Down Expand Up @@ -183,7 +183,7 @@ def __call__(self, img, cls=True, mfd_res=None):
time_dict['cls'] = elapse
logger.debug("cls num : {}, elapsed : {}".format(
len(img_crop_list), elapse))
if self.use_onnx:
if self.lang in ['ch'] and self.use_onnx:
rec_res, elapse = self.additional_ocr.text_recognizer(img_crop_list)
else:
rec_res, elapse = self.text_recognizer(img_crop_list)
Expand Down
2 changes: 1 addition & 1 deletion magic_pdf/resources/model_config/model_configs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ weights:
struct_eqtable: TabRec/StructEqTable
tablemaster: TabRec/TableMaster
rapid_table: TabRec/RapidTable
yolo_v11n_langdetect: LangDetect/YOLO/yolo_v11_cls_ft.pt
yolo_v11n_langdetect: LangDetect/YOLO/yolo_v11_ft.pt

0 comments on commit aa53531

Please sign in to comment.