Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(langdetect): simplify language detection model #1453

Merged
merged 3 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/README_Ascend_NPU_Acceleration_zh_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ magic-pdf --help

## 已知问题

- paddleocr使用内嵌onnx模型,仅支持中英文ocr,不支持其他语言ocr
- paddleocr使用内嵌onnx模型,仅在默认语言配置下能以较快速度对中英文进行识别
- 自定义lang参数时,paddleocr速度会存在明显下降情况
- layout模型使用layoutlmv3时会发生间歇性崩溃,建议使用默认配置的doclayout_yolo模型
- 表格解析仅适配了rapid_table模型,其他模型可能会无法使用
1 change: 1 addition & 0 deletions magic_pdf/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def __init__(self, bits: bytes, lang=None):
logger.info(f"lang: {lang}, detect_lang: {self._lang}")
else:
self._lang = lang
logger.info(f"lang: {lang}")
def __len__(self) -> int:
"""The page number of the pdf."""
return len(self._records)
Expand Down
1 change: 1 addition & 0 deletions magic_pdf/model/model_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ class AtomicModel:
MFR = "mfr"
OCR = "ocr"
Table = "table"
LangDetect = "langdetect"
35 changes: 24 additions & 11 deletions magic_pdf/model/sub_modules/language_detection/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from magic_pdf.libs.config_reader import get_local_models_dir, get_device
from magic_pdf.libs.pdf_check import extract_pages
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton


Expand Down Expand Up @@ -59,15 +58,29 @@ def get_text_images(simple_images):
def auto_detect_lang(pdf_bytes: bytes):
sample_docs = extract_pages(pdf_bytes)
sample_pdf_bytes = sample_docs.tobytes()
simple_images = load_images_from_pdf(sample_pdf_bytes, dpi=96)
simple_images = load_images_from_pdf(sample_pdf_bytes, dpi=200)
text_images = get_text_images(simple_images)
local_models_dir, device, configs = get_model_config()
# 用yolo11做语言分类
langdetect_model_weights = str(
os.path.join(
local_models_dir, configs['weights'][MODEL_NAME.YOLO_V11_LangDetect]
)
)
langdetect_model = YOLOv11LangDetModel(langdetect_model_weights, device)
langdetect_model = model_init(MODEL_NAME.YOLO_V11_LangDetect)
lang = langdetect_model.do_detect(text_images)
return lang
return lang


def model_init(model_name: str):
atom_model_manager = AtomModelSingleton()

if model_name == MODEL_NAME.YOLO_V11_LangDetect:
local_models_dir, device, configs = get_model_config()
model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.LangDetect,
langdetect_model_name=MODEL_NAME.YOLO_V11_LangDetect,
langdetect_model_weight=str(
os.path.join(
local_models_dir, configs['weights'][MODEL_NAME.YOLO_V11_LangDetect]
)
),
device=device,
)
else:
raise ValueError(f"model_name {model_name} not found")
return model

15 changes: 10 additions & 5 deletions magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import Counter
from uuid import uuid4

import torch
from PIL import Image
from loguru import logger
from ultralytics import YOLO
Expand Down Expand Up @@ -83,10 +84,14 @@ def resize_images_to_224(image):


class YOLOv11LangDetModel(object):
def __init__(self, weight, device):
self.model = YOLO(weight)
self.device = device
def __init__(self, langdetect_model_weight, device):

self.model = YOLO(langdetect_model_weight)

if str(device).startswith("npu"):
self.device = torch.device(device)
else:
self.device = device
def do_detect(self, images: list):
all_images = []
for image in images:
Expand All @@ -99,15 +104,14 @@ def do_detect(self, images: list):
all_images.append(resize_images_to_224(temp_image))

images_lang_res = self.batch_predict(all_images, batch_size=8)
logger.info(f"images_lang_res: {images_lang_res}")
# logger.info(f"images_lang_res: {images_lang_res}")
if len(images_lang_res) > 0:
count_dict = Counter(images_lang_res)
language = max(count_dict, key=count_dict.get)
else:
language = None
return language


def predict(self, image):
results = self.model.predict(image, verbose=False, device=self.device)
predicted_class_id = int(results[0].probs.top1)
Expand All @@ -117,6 +121,7 @@ def predict(self, image):

def batch_predict(self, images: list, batch_size: int) -> list:
images_lang_res = []

for index in range(0, len(images), batch_size):
lang_res = [
image_res.cpu()
Expand Down
21 changes: 20 additions & 1 deletion magic_pdf/model/sub_modules/model_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from loguru import logger

from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.libs.config_reader import get_device
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
DocLayoutYOLOModel
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
Expand Down Expand Up @@ -63,6 +63,13 @@ def doclayout_yolo_model_init(weight, device='cpu'):
return model


def langdetect_model_init(langdetect_model_weight, device='cpu'):
if str(device).startswith("npu"):
device = torch.device(device)
model = YOLOv11LangDetModel(langdetect_model_weight, device)
return model


def ocr_model_init(show_log: bool = False,
det_db_box_thresh=0.3,
lang=None,
Expand Down Expand Up @@ -130,6 +137,9 @@ def atom_model_init(model_name: str, **kwargs):
kwargs.get('doclayout_yolo_weights'),
kwargs.get('device')
)
else:
logger.error('layout model name not allow')
exit(1)
elif model_name == AtomicModel.MFD:
atom_model = mfd_model_init(
kwargs.get('mfd_weights'),
Expand All @@ -155,6 +165,15 @@ def atom_model_init(model_name: str, **kwargs):
kwargs.get('device'),
kwargs.get('ocr_engine')
)
elif model_name == AtomicModel.LangDetect:
if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect:
atom_model = langdetect_model_init(
kwargs.get('langdetect_model_weight'),
kwargs.get('device')
)
else:
logger.error('langdetect model name not allow')
exit(1)
else:
logger.error('model name not allow')
exit(1)
Expand Down
10 changes: 5 additions & 5 deletions magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class ModifiedPaddleOCR(PaddleOCR):
def __init__(self, *args, **kwargs):

super().__init__(*args, **kwargs)

self.lang = kwargs.get('lang', 'ch')
# 在cpu架构为arm且不支持cuda时调用onnx、
if not torch.cuda.is_available() and platform.machine() in ['arm64', 'aarch64']:
self.use_onnx = True
Expand Down Expand Up @@ -94,7 +94,7 @@ def preprocess_image(_image):
ocr_res = []
for img in imgs:
img = preprocess_image(img)
if self.use_onnx:
if self.lang in ['ch'] and self.use_onnx:
dt_boxes, elapse = self.additional_ocr.text_detector(img)
else:
dt_boxes, elapse = self.text_detector(img)
Expand Down Expand Up @@ -124,7 +124,7 @@ def preprocess_image(_image):
img, cls_res_tmp, elapse = self.text_classifier(img)
if not rec:
cls_res.append(cls_res_tmp)
if self.use_onnx:
if self.lang in ['ch'] and self.use_onnx:
rec_res, elapse = self.additional_ocr.text_recognizer(img)
else:
rec_res, elapse = self.text_recognizer(img)
Expand All @@ -142,7 +142,7 @@ def __call__(self, img, cls=True, mfd_res=None):

start = time.time()
ori_im = img.copy()
if self.use_onnx:
if self.lang in ['ch'] and self.use_onnx:
dt_boxes, elapse = self.additional_ocr.text_detector(img)
else:
dt_boxes, elapse = self.text_detector(img)
Expand Down Expand Up @@ -183,7 +183,7 @@ def __call__(self, img, cls=True, mfd_res=None):
time_dict['cls'] = elapse
logger.debug("cls num : {}, elapsed : {}".format(
len(img_crop_list), elapse))
if self.use_onnx:
if self.lang in ['ch'] and self.use_onnx:
rec_res, elapse = self.additional_ocr.text_recognizer(img_crop_list)
else:
rec_res, elapse = self.text_recognizer(img_crop_list)
Expand Down
2 changes: 1 addition & 1 deletion magic_pdf/resources/model_config/model_configs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ weights:
struct_eqtable: TabRec/StructEqTable
tablemaster: TabRec/TableMaster
rapid_table: TabRec/RapidTable
yolo_v11n_langdetect: LangDetect/YOLO/yolo_v11_cls_ft.pt
yolo_v11n_langdetect: LangDetect/YOLO/yolo_v11_ft.pt
Loading